File: translation.py

package info (click to toggle)
brian 2.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,872 kB
  • sloc: python: 51,820; cpp: 2,033; makefile: 108; sh: 72
file content (433 lines) | stat: -rw-r--r-- 16,878 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""
This module translates a series of statements into a language-specific
syntactically correct code block that can be inserted into a template.

Along the way it infers, for example, whether a variable can be declared as
constant, and it handles (caches) subexpressions used by the statements.

The input information needed:

* The sequence of statements (a multiline string) in standard mathematical form
* The list of known variables, common subexpressions and functions, and for each
  variable whether or not it is a value or an array, and if an array what the
  dtype is.
* The dtype to use for newly created variables
* The language to translate to
"""

import re
from collections.abc import Mapping

import numpy as np
import sympy

from brian2.core.functions import Function
from brian2.core.preferences import prefs
from brian2.core.variables import AuxiliaryVariable, Subexpression, Variable
from brian2.parsing.bast import brian_ast
from brian2.parsing.statements import parse_statement
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.utils.caching import cached
from brian2.utils.stringtools import deindent, get_identifiers, strip_empty_lines
from brian2.utils.topsort import topsort

from .optimisation import optimise_statements
from .statements import Statement

# Names exported by a star-import of this module
__all__ = ["analyse_identifiers", "get_identifiers_recursively"]


class LineInfo:
    """
    Minimal attribute container.

    Every keyword argument given to the constructor becomes an instance
    attribute of the same name; further attributes can be attached later.
    """

    def __init__(self, **kwds):
        # For a plain class this is the same as one setattr call per item
        self.__dict__.update(kwds)

    # TODO: This information should go somewhere else, I guess


# Python keywords/literals that may appear in abstract code. `analyse_identifiers`
# treats them as known names, so they are never reported as unknown
# (dependent), and it removes them from the set of used known names.
STANDARD_IDENTIFIERS = {"and", "or", "not", "True", "False"}


def analyse_identifiers(code, variables, recursive=False):
    """
    Analyses a code string (sequence of statements) to find all identifiers by type.

    In a given code block, some variable names (identifiers) must be given as inputs to the code
    block, and some are created by the code block. For example, the line::

        a = b+c

    This could mean to create a new variable a from b and c, or it could mean modify the existing
    value of a from b or c, depending on whether a was previously known.

    Parameters
    ----------
    code : str
        The code string, a sequence of statements one per line.
    variables : dict of `Variable`, set of names
        Specifiers for the model variables or a set of known names. If a
        collection of names (rather than a mapping) is given, each name is
        wrapped in a `Variable` with dtype ``float64``.
    recursive : bool, optional
        Whether to recurse down into subexpressions (defaults to ``False``).

    Returns
    -------
    newly_defined : set
        A set of variables that are created (declared with ``:=``) by the
        code block.
    used_known : set
        A set of variables that are used and already known, i.e. a subset of
        the names in ``variables``. The standard identifiers ``and``, ``or``,
        ``not``, ``True`` and ``False`` are never included.
    unknown : set
        A set of variables which are used by the code block but not defined by
        it and not previously known. Should correspond to variables in the
        external namespace.
    """
    if isinstance(variables, Mapping):
        # NOTE(review): this tests the *key* (a name), which is never an
        # AuxiliaryVariable, so no names are filtered out here.
        # `make_statements` filters on the value instead. Switching to the
        # value would change which names callers receive as ``used_known``,
        # so the current behaviour is kept pending review.
        known = {
            k for k, v in variables.items() if not isinstance(k, AuxiliaryVariable)
        }
    else:
        known = set(variables)
        variables = {k: Variable(name=k, dtype=np.float64) for k in known}

    known |= STANDARD_IDENTIFIERS
    scalar_stmts, vector_stmts = make_statements(
        code, variables, np.float64, optimise=False
    )
    stmts = scalar_stmts + vector_stmts
    defined = {stmt.var for stmt in stmts if stmt.op == ":="}
    if not stmts:
        allids = set()
    else:
        written = {stmt.var for stmt in stmts}
        expressions = [stmt.expr for stmt in stmts]
        if recursive:
            # ``variables`` is always a mapping at this point (a collection of
            # names has been converted above), so subexpressions can be
            # looked up in it.
            allids = get_identifiers_recursively(expressions, variables) | written
        else:
            allids = (
                set.union(*[get_identifiers(expr) for expr in expressions]) | written
            )
    dependent = allids.difference(defined, known)
    used_known = allids.intersection(known) - STANDARD_IDENTIFIERS
    return defined, used_known, dependent


def get_identifiers_recursively(expressions, variables, include_numbers=False):
    """
    Gets all the identifiers in a list of expressions, recursing down into
    subexpressions.

    Parameters
    ----------
    expressions : list of str
        List of expressions to check.
    variables : dict-like
        Dictionary of `Variable` objects
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to ``False``.

    Returns
    -------
    identifiers : set of str
        The identifiers used in ``expressions``, plus (transitively) the
        identifiers used in the expression of every `Subexpression` found
        among them.
    """
    identifiers = set()
    for expr in expressions:
        identifiers |= get_identifiers(expr, include_numbers=include_numbers)
    # Use a worklist instead of recursion so that each name is expanded at
    # most once. A subexpression shared by several others is not re-parsed for
    # every path that reaches it, and a cyclic subexpression definition cannot
    # cause unbounded recursion.
    to_expand = list(identifiers)
    while to_expand:
        name = to_expand.pop()
        if name in variables and isinstance(variables[name], Subexpression):
            sub_identifiers = get_identifiers(
                variables[name].expr, include_numbers=include_numbers
            )
            # Only names seen for the first time still need to be expanded
            to_expand.extend(sub_identifiers - identifiers)
            identifiers |= sub_identifiers
    return identifiers


def is_scalar_expression(expr, variables):
    """
    Check whether an expression only involves scalar quantities.

    Parameters
    ----------
    expr : str
        The expression to check
    variables : dict-like
        `Variable` and `Function` objects for the identifiers used in `expr`

    Returns
    -------
    scalar : bool
        ``True`` if every identifier of `expr` (including identifiers reached
        through subexpressions) is either unknown, a scalar variable, or a
        stateless function; ``False`` otherwise.
    """
    for name in get_identifiers_recursively([expr], variables):
        if name not in variables:
            # Unknown identifiers are assumed to be scalar constants -- this
            # covers numerical literals and names such as "True" or "inf".
            continue
        obj = variables[name]
        if getattr(obj, "scalar", False):
            continue
        if isinstance(obj, Function) and obj.stateless:
            continue
        return False
    return True


@cached
def make_statements(code, variables, dtype, optimise=True, blockname=""):
    """
    make_statements(code, variables, dtype, optimise=True, blockname='')

    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`. It is copied, not modified.
    dtype : `dtype`
        The data type to use for temporary variables whose expression is
        neither boolean nor integer.
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``. Callers that are not interested in this kind of optimisation
        (e.g. `analyse_identifiers`) pass ``False`` explicitly. Even when
        ``True``, the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set.
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If the code writes to a subexpression, writes to a scalar variable
        after a vector variable has been written, or if a scalar subexpression
        refers to an implicitly vectorised function.

    Notes
    -----
    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    """
    code = strip_empty_lines(deindent(code))
    lines = re.split(r"[;\n]", code)
    # Skip empty or whitespace-only fragments (e.g. produced by a trailing
    # "; "), they do not contain a statement
    lines = [LineInfo(code=line) for line in lines if line.strip()]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    # (auxiliary variables are temporaries and still have to be declared)
    defined = {k for k, v in variables.items() if not isinstance(v, AuxiliaryVariable)}
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if var in variables and isinstance(variables[var], Subexpression):
            raise SyntaxError(
                f"Illegal line '{line.code}' in abstract code. Cannot write to"
                f" subexpression '{var}'."
            )
        if op == "=":
            if var not in defined:
                # First assignment to a not yet defined name: declare it
                op = ":="
                defined.add(var)
                if var not in variables:
                    # Infer type and scalar-ness of the new temporary variable
                    # from the expression assigned to it
                    annotated_ast = brian_ast(expr, variables)
                    is_scalar = annotated_ast.scalar
                    if annotated_ast.dtype == "boolean":
                        use_dtype = bool
                    elif annotated_ast.dtype == "integer":
                        use_dtype = int
                    else:
                        use_dtype = dtype
                    new_var = AuxiliaryVariable(var, dtype=use_dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Try to rewrite "x = x + expr" as "x += expr" and
                # "x = x * expr" as "x *= expr"
                sympy_expr = str_to_sympy(expr, variables)
                if variables[var].is_integer:
                    sympy_var = sympy.Symbol(var, integer=True)
                else:
                    sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(
                        sympy_expr, sympy_var, exact=True, evaluate=False
                    )
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}

                if (
                    len(collected) == 2
                    and set(collected.keys()) == {1, sympy_var}
                    and collected[sympy_var] == 1
                ):
                    # We can replace this statement by a += assignment
                    statement = Statement(
                        var,
                        "+=",
                        sympy_to_str(collected[1]),
                        comment,
                        dtype=variables[var].dtype,
                        scalar=variables[var].scalar,
                    )
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(
                        var,
                        "*=",
                        sympy_to_str(collected[sympy_var]),
                        comment,
                        dtype=variables[var].dtype,
                        scalar=variables[var].scalar,
                    )
        if statement is None:
            statement = Statement(
                var,
                op,
                expr,
                comment,
                dtype=variables[var].dtype,
                scalar=variables[var].scalar,
            )

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ":=" and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                "All writes to scalar variables in a code block "
                "have to be made before writes to vector "
                f"variables. Illegal write to '{line.write}'."
            )
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = {
        name: val for name, val in variables.items() if isinstance(val, Subexpression)
    }
    # Check that no scalar subexpression refers to a vectorised function
    # (e.g. rand()) -- otherwise it would be differently interpreted depending
    # on whether it is used in a scalar or a vector context (i.e., even though
    # the subexpression is supposed to be scalar, it would be vectorised when
    # used as part of non-scalar expressions)
    for name, subexpr in subexpressions.items():
        if subexpr.scalar:
            identifiers = get_identifiers(subexpr.expr)
            for identifier in identifiers:
                if identifier in variables and getattr(
                    variables[identifier], "auto_vectorise", False
                ):
                    raise SyntaxError(
                        f"The scalar subexpression '{name}' refers to "
                        f"the implicitly vectorised function '{identifier}' "
                        "-- this is not allowed since it leads "
                        "to different interpretations of this "
                        "subexpression depending on whether it "
                        "is used in a scalar or vector "
                        "context."
                    )

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = {
        name: [dep for dep in subexpr.identifiers if dep in subexpressions]
        for name, subexpr in subexpressions.items()
    }
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = {name: None for name in subexpressions}
    for line in lines:
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in line.read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var] == "constant":
                continue
            elif subdefined[var] == "variable":
                # re-evaluate, its inputs may have changed in the meantime
                op = "="
                constant = False
            else:
                op = ":="
                # check if the referred variables ever change. The identifiers
                # are collected recursively: a subexpression referring to
                # another subexpression whose inputs are written later must
                # not be declared constant, otherwise it would keep a stale
                # value after the inner subexpression has been recomputed.
                ids = get_identifiers_recursively([subexpression.expr], variables)
                constant = all(v not in line.will_write for v in ids)
                subdefined[var] = "constant" if constant else "variable"

            statement = Statement(
                var,
                op,
                subexpression.expr,
                comment="",
                dtype=variables[var].dtype,
                constant=constant,
                subexpression=True,
                scalar=variables[var].scalar,
            )
            statements.append(statement)

        stmt = line.statement
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ":=" and var not in line.will_write
        statement = Statement(
            var,
            op,
            expr,
            comment,
            dtype=variables[var].dtype,
            constant=constant,
            scalar=variables[var].scalar,
        )
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements, vector_statements, variables, blockname=blockname
        )

    return scalar_statements, vector_statements