1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
|
"""
This module translates a series of statements into a language-specific
syntactically correct code block that can be inserted into a template.
It infers, for example, whether or not a variable can be declared as
constant, and it handles common subexpressions.
The input information needed:
* The sequence of statements (a multiline string) in standard mathematical form
* The list of known variables, common subexpressions and functions, and for each
variable whether or not it is a value or an array, and if an array what the
dtype is.
* The dtype to use for newly created variables
* The language to translate to
"""
import re
from collections.abc import Mapping
import numpy as np
import sympy
from brian2.core.functions import Function
from brian2.core.preferences import prefs
from brian2.core.variables import AuxiliaryVariable, Subexpression, Variable
from brian2.parsing.bast import brian_ast
from brian2.parsing.statements import parse_statement
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.utils.caching import cached
from brian2.utils.stringtools import deindent, get_identifiers, strip_empty_lines
from brian2.utils.topsort import topsort
from .optimisation import optimise_statements
from .statements import Statement
# Names exported by ``from <this module> import *``. ``make_statements`` and
# ``is_scalar_expression`` are defined in this module as well but are not
# listed here.
__all__ = ["analyse_identifiers", "get_identifiers_recursively"]
class LineInfo:
    """
    Lightweight attribute container for one line of abstract code.

    Each keyword argument given to the constructor becomes an instance
    attribute of the same name. Code that uses this class (e.g.
    ``make_statements``) attaches further attributes to the instance later.
    """

    def __init__(self, **kwds):
        # Plain class without __slots__ or a custom __setattr__, so updating
        # the instance namespace is equivalent to calling setattr per item.
        vars(self).update(kwds)
# Boolean operators and literals that may appear as identifiers in abstract
# code without referring to any model variable; they are always "known".
# TODO: This information should probably live somewhere more central.
STANDARD_IDENTIFIERS = {"True", "False", "and", "not", "or"}
def analyse_identifiers(code, variables, recursive=False):
    """
    Analyses a code string (sequence of statements) to find all identifiers by type.

    In a given code block, some variable names (identifiers) must be given as
    inputs to the code block, and some are created by the code block. For
    example, the line::

        a = b+c

    This could mean to create a new variable a from b and c, or it could mean
    modify the existing value of a from b or c, depending on whether a was
    previously known.

    Parameters
    ----------
    code : str
        The code string, a sequence of statements one per line.
    variables : dict of `Variable`, set of names
        Specifiers for the model variables or a set of known names. When a
        mapping is given, entries whose value is an `AuxiliaryVariable` are
        not treated as known (consistent with `make_statements`).
    recursive : bool, optional
        Whether to recurse down into subexpressions (defaults to ``False``).

    Returns
    -------
    newly_defined : set
        A set of variables that are created by the code block.
    used_known : set
        A set of variables that are used and already known, a subset of the
        ``known`` parameter.
    unknown : set
        A set of variables which are used by the code block but not defined by
        it and not previously known. Should correspond to variables in the
        external namespace.
    """
    if isinstance(variables, Mapping):
        # Check the *value* for AuxiliaryVariable (the key is just a name
        # string); make_statements excludes auxiliary variables the same way.
        known = {
            k for k, v in variables.items() if not isinstance(v, AuxiliaryVariable)
        }
    else:
        known = set(variables)
        # make_statements needs Variable objects, so wrap the bare names
        variables = {k: Variable(name=k, dtype=np.float64) for k in known}
    known |= STANDARD_IDENTIFIERS

    scalar_stmts, vector_stmts = make_statements(
        code, variables, np.float64, optimise=False
    )
    stmts = scalar_stmts + vector_stmts
    defined = {stmt.var for stmt in stmts if stmt.op == ":="}
    written = {stmt.var for stmt in stmts}

    if not stmts:
        allids = set()
    elif recursive:
        # ``variables`` is always a mapping at this point (a set of names was
        # converted into a dict above), so subexpressions can be looked up.
        allids = (
            get_identifiers_recursively([stmt.expr for stmt in stmts], variables)
            | written
        )
    else:
        allids = set.union(*[get_identifiers(stmt.expr) for stmt in stmts]) | written

    dependent = allids.difference(defined, known)
    used_known = allids.intersection(known) - STANDARD_IDENTIFIERS
    return defined, used_known, dependent
def get_identifiers_recursively(expressions, variables, include_numbers=False):
    """
    Collect every identifier used in ``expressions``, following subexpressions.

    Parameters
    ----------
    expressions : list of str
        List of expressions to check.
    variables : dict-like
        Dictionary of `Variable` objects
    include_numbers : bool, optional
        Whether to include number literals in the output. Defaults to ``False``.

    Returns
    -------
    identifiers : set
        The identifiers appearing directly in ``expressions``, plus, for each
        of them that names a `Subexpression` in ``variables``, all identifiers
        found (recursively) in that subexpression's ``expr``.
    """
    found = set()
    for expression in expressions:
        found |= get_identifiers(expression, include_numbers=include_numbers)

    # Expand only the directly found names here (snapshot taken so that the
    # set can grow while we loop); deeper levels are handled by recursion.
    direct_names = list(found)
    for identifier in direct_names:
        if identifier not in variables:
            continue
        candidate = variables[identifier]
        if not isinstance(candidate, Subexpression):
            continue
        found |= get_identifiers_recursively(
            [candidate.expr], variables, include_numbers=include_numbers
        )
    return found
def is_scalar_expression(expr, variables):
    """
    Determine whether an expression evaluates to a scalar.

    Parameters
    ----------
    expr : str
        The expression to check
    variables : dict-like
        `Variable` and `Function` object for all the identifiers used in `expr`

    Returns
    -------
    scalar : bool
        ``True`` if every identifier used by `expr` (including those reached
        through subexpressions) is unknown, a scalar variable, or a stateless
        function; ``False`` otherwise.
    """
    for name in get_identifiers_recursively([expr], variables):
        # Identifiers not in ``variables`` are assumed to be scalar
        # constants -- this covers numerical literals and e.g. "True"/"inf".
        if name not in variables:
            continue
        spec = variables[name]
        if getattr(spec, "scalar", False):
            continue
        if isinstance(spec, Function) and spec.stateless:
            continue
        # Found something vectorised (or a stateful function)
        return False
    return True
@cached
def make_statements(code, variables, dtype, optimise=True, blockname=""):
    """
    make_statements(code, variables, dtype, optimise=True, blockname='')

    Turn a series of abstract code statements into Statement objects, inferring
    whether each line is a set/declare operation, whether the variables are
    constant or not, and handling the caching of subexpressions.

    Parameters
    ----------
    code : str
        A (multi-line) string of statements. Statements are separated by
        newlines or semicolons.
    variables : dict-like
        A dictionary with `Variable` and `Function` objects for every
        identifier used in the `code`.
    dtype : `dtype`
        The data type to use for temporary variables
    optimise : bool, optional
        Whether to optimise expressions, including pulling out loop invariant
        expressions and putting them in new scalar constants. Defaults to
        ``True``; the optimisation is only applied if the
        `codegen.loop_invariant_optimisations` preference is set as well.
        Pass ``False`` in contexts where this kind of optimisation is of no
        interest (e.g. when only analysing identifiers).
    blockname : str, optional
        A name for the block (used to name intermediate variables to avoid
        name clashes when multiple blocks are used together)

    Returns
    -------
    scalar_statements, vector_statements : (list of `Statement`, list of `Statement`)
        Lists with statements that are to be executed once and statements that
        are to be executed once for every neuron/synapse/... (or in a vectorised
        way)

    Raises
    ------
    SyntaxError
        If the code writes to a subexpression, writes to a scalar variable
        after a write to a vector variable, or if a scalar subexpression refers
        to an implicitly vectorised function.

    Notes
    -----
    If ``optimise`` is ``True``, then the
    ``scalar_statements`` may include newly introduced scalar constants that
    have been identified as loop-invariant and have therefore been pulled out
    of the vector statements. The resulting statements will also use augmented
    assignments where possible, i.e. a statement such as ``w = w + 1`` will be
    replaced by ``w += 1``. Also, statements involving booleans will have
    additional information added to them (see `Statement` for details)
    describing how the statement can be reformulated as a sequence of if/then
    statements. Calls `~brian2.codegen.optimisation.optimise_statements`.
    """
    code = strip_empty_lines(deindent(code))
    lines = re.split(r"[;\n]", code)
    # Skip fragments consisting only of whitespace (e.g. the text after a
    # trailing semicolon) -- they cannot be parsed as statements.
    lines = [LineInfo(code=line) for line in lines if line.strip()]
    # Do a copy so we can add stuff without altering the original dict
    variables = dict(variables)
    # we will do inference to work out which lines are := and which are =
    defined = {k for k, v in variables.items() if not isinstance(v, AuxiliaryVariable)}
    for line in lines:
        statement = None
        # parse statement into "var op expr"
        var, op, expr, comment = parse_statement(line.code)
        if var in variables and isinstance(variables[var], Subexpression):
            raise SyntaxError(
                f"Illegal line '{line.code}' in abstract code. Cannot write to"
                f" subexpression '{var}'."
            )
        if op == "=":
            if var not in defined:
                # First assignment to this name: it is a declaration
                op = ":="
                defined.add(var)
                if var not in variables:
                    # Create a new temporary variable, with its type inferred
                    # from the right-hand side
                    annotated_ast = brian_ast(expr, variables)
                    is_scalar = annotated_ast.scalar
                    if annotated_ast.dtype == "boolean":
                        use_dtype = bool
                    elif annotated_ast.dtype == "integer":
                        use_dtype = int
                    else:
                        use_dtype = dtype
                    new_var = AuxiliaryVariable(var, dtype=use_dtype, scalar=is_scalar)
                    variables[var] = new_var
            elif not variables[var].is_boolean:
                # Re-assignment of an existing non-boolean variable: try to
                # express it as an augmented assignment (+= or *=)
                sympy_expr = str_to_sympy(expr, variables)
                if variables[var].is_integer:
                    sympy_var = sympy.Symbol(var, integer=True)
                else:
                    sympy_var = sympy.Symbol(var, real=True)
                try:
                    collected = sympy.collect(
                        sympy_expr, sympy_var, exact=True, evaluate=False
                    )
                except AttributeError:
                    # If something goes wrong during collection, e.g. collect
                    # does not work for logical expressions
                    collected = {1: sympy_expr}
                if (
                    len(collected) == 2
                    and set(collected.keys()) == {1, sympy_var}
                    and collected[sympy_var] == 1
                ):
                    # We can replace this statement by a += assignment
                    statement = Statement(
                        var,
                        "+=",
                        sympy_to_str(collected[1]),
                        comment,
                        dtype=variables[var].dtype,
                        scalar=variables[var].scalar,
                    )
                elif len(collected) == 1 and sympy_var in collected:
                    # We can replace this statement by a *= assignment
                    statement = Statement(
                        var,
                        "*=",
                        sympy_to_str(collected[sympy_var]),
                        comment,
                        dtype=variables[var].dtype,
                        scalar=variables[var].scalar,
                    )
        if statement is None:
            statement = Statement(
                var,
                op,
                expr,
                comment,
                dtype=variables[var].dtype,
                scalar=variables[var].scalar,
            )

        line.statement = statement
        # for each line will give the variable being written to
        line.write = var
        # each line will give a set of variables which are read
        line.read = get_identifiers_recursively([expr], variables)

    # All writes to scalar variables must happen before writes to vector
    # variables
    scalar_write_done = False
    for line in lines:
        stmt = line.statement
        if stmt.op != ":=" and variables[stmt.var].scalar and scalar_write_done:
            raise SyntaxError(
                "All writes to scalar variables in a code block "
                "have to be made before writes to vector "
                f"variables. Illegal write to '{line.write}'."
            )
        elif not variables[stmt.var].scalar:
            scalar_write_done = True

    # backwards compute whether or not variables will be read again
    # note that will_read for a line gives the set of variables it will read
    # on the current line or subsequent ones. will_write gives the set of
    # variables that will be written after the current line
    will_read = set()
    will_write = set()
    for line in lines[::-1]:
        will_read = will_read.union(line.read)
        line.will_read = will_read.copy()
        line.will_write = will_write.copy()
        will_write.add(line.write)

    subexpressions = {
        name: val for name, val in variables.items() if isinstance(val, Subexpression)
    }

    # Check that no scalar subexpression refers to a vectorised function
    # (e.g. rand()) -- otherwise it would be differently interpreted depending
    # on whether it is used in a scalar or a vector context (i.e., even though
    # the subexpression is supposed to be scalar, it would be vectorised when
    # used as part of non-scalar expressions)
    for name, subexpr in subexpressions.items():
        if subexpr.scalar:
            identifiers = get_identifiers(subexpr.expr)
            for identifier in identifiers:
                if identifier in variables and getattr(
                    variables[identifier], "auto_vectorise", False
                ):
                    raise SyntaxError(
                        f"The scalar subexpression '{name}' refers to "
                        f"the implicitly vectorised function '{identifier}' "
                        "-- this is not allowed since it leads "
                        "to different interpretations of this "
                        "subexpression depending on whether it "
                        "is used in a scalar or vector "
                        "context."
                    )

    # sort subexpressions into an order so that subexpressions that don't depend
    # on other subexpressions are first
    subexpr_deps = {
        name: [dep for dep in subexpr.identifiers if dep in subexpressions]
        for name, subexpr in subexpressions.items()
    }
    sorted_subexpr_vars = topsort(subexpr_deps)

    statements = []

    # none are yet defined (or declared)
    subdefined = {name: None for name in subexpressions}
    for line in lines:
        # update/define all subexpressions needed by this statement
        for var in sorted_subexpr_vars:
            if var not in line.read:
                continue

            subexpression = subexpressions[var]
            # if already defined/declared
            if subdefined[var] == "constant":
                continue
            elif subdefined[var] == "variable":
                op = "="
                constant = False
            else:
                op = ":="
                # Check if the referred variables ever change. This has to
                # include variables referred to indirectly via nested
                # subexpressions, and the variable written by the current
                # line: the subexpression is evaluated *before* that write,
                # so its value is stale for any later use.
                ids = set(subexpression.identifiers) | get_identifiers_recursively(
                    [subexpression.expr], variables
                )
                changing = line.will_write | {line.write}
                constant = ids.isdisjoint(changing)
                subdefined[var] = "constant" if constant else "variable"

            statement = Statement(
                var,
                op,
                subexpression.expr,
                comment="",
                dtype=variables[var].dtype,
                constant=constant,
                subexpression=True,
                scalar=variables[var].scalar,
            )
            statements.append(statement)

        stmt = line.statement
        var, op, expr, comment = stmt.var, stmt.op, stmt.expr, stmt.comment

        # constant only if we are declaring a new variable and we will not
        # write to it again
        constant = op == ":=" and var not in line.will_write
        statement = Statement(
            var,
            op,
            expr,
            comment,
            dtype=variables[var].dtype,
            constant=constant,
            scalar=variables[var].scalar,
        )
        statements.append(statement)

    scalar_statements = [s for s in statements if s.scalar]
    vector_statements = [s for s in statements if not s.scalar]

    if optimise and prefs.codegen.loop_invariant_optimisations:
        scalar_statements, vector_statements = optimise_statements(
            scalar_statements, vector_statements, variables, blockname=blockname
        )

    return scalar_statements, vector_statements
|