1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
|
"""
Module defining `CodeString`, a class for a string of code together with
information about its namespace. Only serves as a parent class, its subclasses
`Expression` and `Statements` are the ones that are actually used.
"""
from collections.abc import Hashable
import sympy
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers
__all__ = ["Expression", "Statements"]
logger = get_logger(__name__)
class CodeString(Hashable):
    """
    Base class for "code strings": text holding either one Python expression
    or a sequence of Python statements.

    Instances compare equal (and hash identically) when their stored code
    text is the same.

    Parameters
    ----------
    code : str
        The code to store. It can be an expression or one or more statements
        and may span several lines. Leading and trailing whitespace is
        stripped before the text is stored.
    """

    def __init__(self, code):
        stripped = code.strip()
        self._code = stripped

        #: Set of identifiers in the code string
        self.identifiers = get_identifiers(code)

    @property
    def code(self):
        """The code string"""
        return self._code

    def __str__(self):
        return self.code

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.code!r})"

    def __eq__(self, other):
        # Only other code strings are comparable; for anything else let
        # Python try the reflected comparison.
        if isinstance(other, CodeString):
            return self.code == other.code
        return NotImplemented

    def __ne__(self, other):
        are_equal = self == other
        return not are_equal

    def __hash__(self):
        # Consistent with __eq__: equal code text gives an equal hash.
        return hash(self.code)
class Statements(CodeString):
    """
    A code string holding one or more statements.

    Parameters
    ----------
    code : str
        One or several statements. Multiple statements are written either on
        separate lines of a multi-line string or separated by semicolons.

    Notes
    -----
    At present this class adds nothing on top of
    `~brian2.equations.codestrings.CodeString`. It should nevertheless be
    used in place of that class, both for clarity and to leave room for
    future functionality that is only relevant to statements and not to
    expressions.
    """
class Expression(CodeString):
    """
    Class for representing an expression.

    Exactly one of ``code`` and ``sympy_expression`` has to be given.

    Parameters
    ----------
    code : str, optional
        The expression. Note that the expression has to be written in a form
        that is parseable by sympy. Alternatively, a sympy expression can be
        provided (in the ``sympy_expression`` argument).
    sympy_expression : sympy expression, optional
        A sympy expression. Alternatively, a plain string expression can be
        provided (in the ``code`` argument).

    Raises
    ------
    TypeError
        If neither or both of ``code`` and ``sympy_expression`` are provided.
    """

    def __init__(self, code=None, sympy_expression=None):
        if code is None and sympy_expression is None:
            raise TypeError("Have to provide either a string or a sympy expression")
        if code is not None and sympy_expression is not None:
            raise TypeError(
                "Provide a string expression or a sympy expression, not both"
            )

        if code is None:
            code = sympy_to_str(sympy_expression)
        else:
            # Just try to convert it to a sympy expression to get syntax errors
            # for incorrect expressions
            str_to_sympy(code)
        super().__init__(code=code)

    @property
    def stochastic_variables(self):
        """Stochastic variables in this expression"""
        return {
            variable
            for variable in self.identifiers
            if variable == "xi" or variable.startswith("xi_")
        }

    def split_stochastic(self):
        """
        Split the expression into a stochastic and non-stochastic part.

        Splits the expression into a tuple of one `Expression` object f (the
        non-stochastic part) and a dictionary mapping stochastic variables
        to `Expression` objects. For example, an expression of the form
        ``f + g * xi_1 + h * xi_2`` would be returned as:
        ``(f, {'xi_1': g, 'xi_2': h})``
        Note that the `Expression` objects for the stochastic parts do not
        include the stochastic variable itself.

        Returns
        -------
        (f, d) : (`Expression`, dict)
            A tuple of an `Expression` object and a dictionary, the first
            expression being the non-stochastic part of the equation and
            the dictionary mapping stochastic variables (``xi`` or starting
            with ``xi_``) to `Expression` objects. If no stochastic variable
            is present in the code string, a tuple ``(self, None)`` will be
            returned with the unchanged `Expression` object. If the
            expression has no non-stochastic term, ``f`` is the expression
            ``0.0``.

        Raises
        ------
        ValueError
            If the collected terms contain a factor that is neither ``1`` nor
            a single stochastic symbol, i.e. the expression cannot be
            separated into stochastic and non-stochastic terms.
        AssertionError
            If the term determined to be non-stochastic still contains a
            stochastic symbol.
        """
        # Iterate over ``self.identifiers`` directly (rather than over the
        # ``stochastic_variables`` set) to keep the symbol order unchanged.
        stochastic_variables = [
            identifier
            for identifier in self.identifiers
            if identifier == "xi" or identifier.startswith("xi_")
        ]

        # No stochastic variable
        if not stochastic_variables:
            return (self, None)

        stochastic_symbols = [
            sympy.Symbol(variable, real=True) for variable in stochastic_variables
        ]

        # Note that collect only works properly if the expression is expanded
        collected = (
            str_to_sympy(self.code).expand().collect(stochastic_symbols, evaluate=False)
        )

        f_expr = None
        stochastic_expressions = {}
        for var, s_expr in collected.items():
            expr = Expression(sympy_expression=s_expr)
            if var == 1:
                # The key 1 holds the terms not multiplied by any of the
                # stochastic symbols
                if any(s_expr.has(s) for s in stochastic_symbols):
                    raise AssertionError(
                        "Error when separating expression "
                        f"'{self.code}' into stochastic and non-"
                        "stochastic term: non-stochastic "
                        f"part was determined to be '{s_expr}' but "
                        "contains a stochastic symbol."
                    )
                f_expr = expr
            elif var in stochastic_symbols:
                stochastic_expressions[str(var)] = expr
            else:
                raise ValueError(
                    f"Expression '{self.code}' cannot be separated into "
                    "stochastic and non-stochastic "
                    "term"
                )

        if f_expr is None:
            f_expr = Expression("0.0")

        return f_expr, stochastic_expressions

    def _repr_pretty_(self, p, cycle):
        """
        Pretty printing for ipython.
        """
        if cycle:
            raise AssertionError("Cyclical call of 'CodeString._repr_pretty'")
        # Make use of sympy's pretty printing
        p.pretty(str_to_sympy(self.code))

    def __eq__(self, other):
        # Returning NotImplemented (instead of False) for operands that are
        # not an Expression lets Python try the other operand's comparison.
        if not isinstance(other, Expression):
            return NotImplemented
        return self.code == other.code

    def __ne__(self, other):
        # Go through ``==`` instead of calling ``__eq__`` directly:
        # ``__eq__`` may return ``NotImplemented``, which is truthy (so
        # negating it would wrongly report "not unequal") and whose use in a
        # boolean context is deprecated since Python 3.9.
        return not self == other

    def __hash__(self):
        return hash(self.code)
def is_constant_over_dt(expression, variables, dt_value):
    """
    Decide whether an expression may be treated as constant during one time
    step.

    The expression is *not* considered constant if it either:

    1. contains the variable ``t`` (unless ``t`` appears as the argument of a
       function that is itself constant over the time step, e.g. a
       `TimedArray` whose dt is equal to or greater than the dt used to
       evaluate this expression)
    2. refers to a stateful function such as ``rand()``.

    Parameters
    ----------
    expression : `sympy.Expr`
        The (sympy) expression to analyze
    variables : dict
        The variables dictionary.
    dt_value : float or None
        The length of a timestep (without units), can be ``None`` if the
        time step is not yet known.

    Returns
    -------
    is_constant : bool
        Whether the expression can be considered to be constant over a time
        step.
    """
    time_symbol = sympy.Symbol("t", real=True, positive=True)

    # The whole expression is just "t"
    if expression == time_symbol:
        return False

    # Look up the head of the expression; stateful functions are never
    # constant over a time step
    head_name = str(expression.func)
    head_variable = variables.get(head_name, None)
    if head_variable is not None and not head_variable.stateless:
        return False

    for argument in expression.args:
        if argument == time_symbol and dt_value is not None:
            # "t" is a direct argument: only acceptable when the enclosing
            # function is known and locally constant for this dt
            if head_variable is None or not head_variable.is_locally_constant(
                dt_value
            ):
                return False
        elif not is_constant_over_dt(argument, variables, dt_value):
            # Any non-constant sub-expression makes the whole expression
            # non-constant
            return False

    return True
|