"""
Numerical integration functions.
"""

import operator
import string
from functools import reduce

import sympy
from pyparsing import (
    Group,
    Literal,
    ParseException,
    Suppress,
    Word,
    ZeroOrMore,
    restOfLine,
)
from sympy.core.sympify import SympifyError

from brian2.equations.codestrings import is_constant_over_dt
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str

from .base import (
    StateUpdateMethod,
    UnsupportedEquationsException,
    extract_method_options,
)

__all__ = [
    "milstein",
    "heun",
    "euler",
    "rk2",
    "rk4",
    "ExplicitStateUpdater",
]


# ===============================================================================
# Class for simple definition of explicit state updaters
# ===============================================================================


def _symbol(name, positive=None):
    """
    Create a real-valued sympy symbol.

    Shorthand for ``sympy.Symbol(name, real=True, positive=positive)``.

    Parameters
    ----------
    name : str
        The name of the symbol.
    positive : bool, optional
        Passed on as the ``positive`` assumption of the symbol. The default
        ``None`` leaves the sign unspecified.

    Returns
    -------
    sympy.Symbol
        The new symbol.
    """
    return sympy.Symbol(name, real=True, positive=positive)


#: reserved standard symbols
SYMBOLS = dict(
    # state variable and time as seen from inside a state updater description
    __x=_symbol("__x"),
    __t=_symbol("__t", positive=True),
    # time step and the (user-facing) time variable of the model equations
    dt=_symbol("dt", positive=True),
    t=_symbol("t", positive=True),
    # non-stochastic (f) and stochastic (g) parts of the equations
    __f=sympy.Function("__f"),
    __g=sympy.Function("__g"),
    # the random variable
    __dW=_symbol("__dW"),
)


def split_expression(expr):
    """
    Split an expression into a part containing the function ``f`` and another
    one containing the function ``g``. Returns a tuple of the two expressions
    (as sympy expressions).

    Parameters
    ----------
    expr : str
        An expression containing references to functions ``f`` and ``g``.

    Returns
    -------
    (non_stochastic, stochastic) : tuple of sympy expressions
        A pair of expressions representing the non-stochastic (containing
        function-independent terms and terms involving ``f``) and the
        stochastic part of the expression (terms involving ``g`` and/or ``dW``).
        The stochastic part is ``None`` if the expression contains neither
        ``g`` nor a ``dW`` term with a non-zero factor.

    Raises
    ------
    ValueError
        If the expression does not fit the pattern
        ``independent + f_factor*f(...) + factor*dW**exponent + g_factor*g(...)``.

    Examples
    --------
    >>> split_expression('dt * __f(__x, __t)')
    (dt*__f(__x, __t), None)
    >>> split_expression('dt * __f(__x, __t) + __dW * __g(__x, __t)')
    (dt*__f(__x, __t), __dW*__g(__x, __t))
    >>> split_expression('1/(2*sqrt(dt))*(__g_support - __g(__x, __t))*(sqrt(__dW))')
    (0, sqrt(__dW)*__g_support/(2*sqrt(dt)) - sqrt(__dW)*__g(__x, __t)/(2*sqrt(dt)))
    """

    f = SYMBOLS["__f"]
    g = SYMBOLS["__g"]
    dW = SYMBOLS["__dW"]
    # Arguments of the f and g functions
    x_f = sympy.Wild("x_f", exclude=[f, g], real=True)
    t_f = sympy.Wild("t_f", exclude=[f, g], real=True)
    x_g = sympy.Wild("x_g", exclude=[f, g], real=True)
    t_g = sympy.Wild("t_g", exclude=[f, g], real=True)

    # Reorder the expression so that f(x,t) and g(x,t) are factored out
    sympy_expr = sympy.sympify(expr, locals=SYMBOLS).expand()
    sympy_expr = sympy.collect(sympy_expr, f(x_f, t_f))
    sympy_expr = sympy.collect(sympy_expr, g(x_g, t_g))

    # Constant part, contains neither f, g nor dW
    independent = sympy.Wild("independent", exclude=[f, g, dW], real=True)
    # The exponent of the random number
    dW_exponent = sympy.Wild("dW_exponent", exclude=[f, g, dW, 0], real=True)
    # The factor for the random number, not containing the g function
    independent_dW = sympy.Wild("independent_dW", exclude=[f, g, dW], real=True)
    # The factor for the f function
    f_factor = sympy.Wild("f_factor", exclude=[f, g], real=True)
    # The factor for the g function
    g_factor = sympy.Wild("g_factor", exclude=[f, g], real=True)

    match_expr = (
        independent
        + f_factor * f(x_f, t_f)
        + independent_dW * dW**dW_exponent
        + g_factor * g(x_g, t_g)
    )
    matches = sympy_expr.match(match_expr)

    if matches is None:
        raise ValueError(
            f'Expression "{sympy_expr}" in the state updater description could not be'
            " parsed."
        )

    # Non-stochastic part
    if x_f in matches:
        # Includes the f function
        non_stochastic = matches[independent] + (
            matches[f_factor] * f(matches[x_f], matches[t_f])
        )
    else:
        # Does not include f, might be 0
        non_stochastic = matches[independent]

    # The g term only exists if g was matched. A random-variable term may be
    # present without any g (e.g. ``dt*f(x, t) + dW``); the g argument
    # wildcards are then missing from ``matches`` and must not be accessed.
    if x_g in matches:
        g_term = matches[g_factor] * g(matches[x_g], matches[t_g])
    else:
        g_term = None

    # Stochastic part
    if independent_dW in matches and matches[independent_dW] != 0:
        # includes a random variable term with a non-zero factor
        dW_term = matches[independent_dW] * dW ** matches[dW_exponent]
        stochastic = dW_term if g_term is None else g_term + dW_term
    else:
        # Either only the g function, or neither g nor a random variable
        # (in which case this is None)
        stochastic = g_term

    return (non_stochastic, stochastic)


class ExplicitStateUpdater(StateUpdateMethod):
    """
    An object that can be used for defining state updaters via a simple
    description (see below). Resulting instances can be passed to the
    ``method`` argument of the `NeuronGroup` constructor. As other state
    updater functions the `ExplicitStateUpdater` objects are callable,
    returning abstract code when called with an `Equations` object.

    A description of an explicit state updater consists of a (multi-line)
    string, containing assignments to variables and a final "x_new = ...",
    stating the integration result for a single timestep. The assignments
    can be used to define an arbitrary number of intermediate results and
    can refer to ``f(x, t)`` (the function being integrated, as a function of
    ``x``, the previous value of the state variable and ``t``, the time) and
    ``dt``, the size of the timestep.

    For example, to define a Runge-Kutta 4 integrator (already provided as
    `rk4`), use::

            k1 = dt*f(x,t)
            k2 = dt*f(x+k1/2,t+dt/2)
            k3 = dt*f(x+k2/2,t+dt/2)
            k4 = dt*f(x+k3,t+dt)
            x_new = x+(k1+2*k2+2*k3+k4)/6

    Note that for stochastic equations, the function `f` only corresponds to
    the non-stochastic part of the equation. The additional function `g`
    corresponds to the stochastic part that has to be multiplied with the
    stochastic variable xi (a standard normal random variable -- if the
    algorithm needs a random variable with a different variance/mean you have
    to multiply/add it accordingly). Equations with more than one
    stochastic variable do not have to be treated differently, the part
    referring to ``g`` is repeated for all stochastic variables automatically.

    Stochastic integrators can also make reference to ``dW`` (a normal
    distributed random number with variance ``dt``) and ``g(x, t)``, the
    stochastic part of an equation. A stochastic state updater could therefore
    use a description like::

        x_new = x + dt*f(x,t) + g(x, t) * dW

    For simplicity, the same syntax is used for state updaters that only support
    additive noise, even though ``g(x, t)`` does not depend on ``x`` or ``t``
    in that case.

    There are some restrictions on the complexity of the expressions (but most
    can be worked around by using intermediate results as in the above Runge-
    Kutta example): Every statement can only contain the functions ``f`` and
    ``g`` once; The expressions have to be linear in the functions, e.g. you
    can use ``dt*f(x, t)`` but not ``f(x, t)**2``.

    Parameters
    ----------
    description : str
        A state updater description (see above).
    stochastic : {None, 'additive', 'multiplicative'}
        What kind of stochastic equations this state updater supports: ``None``
        means no support of stochastic equations, ``'additive'`` means only
        equations with additive noise and ``'multiplicative'`` means
        supporting arbitrary stochastic equations.
    custom_check : callable, optional
        A function that is called as ``custom_check(eqs, variables)`` before
        the equations are integrated. It can reject equations this state
        updater cannot handle by raising an exception (e.g. `diagonal_noise`).

    Raises
    ------
    SyntaxError
        If the parsing of the description failed.

    Notes
    -----
    Since clocks are updated *after* the state update, the time ``t`` used
    in the state update step is still at its previous value. Enumerating the
    states and discrete times, ``x_new = x + dt*f(x, t)`` is therefore
    understood as :math:`x_{i+1} = x_i + dt f(x_i, t_i)`, yielding the correct
    forward Euler integration. If the integrator has to refer to the time at
    the end of the timestep, simply use ``t + dt`` instead of ``t``.

    See also
    --------
    euler, rk2, rk4, milstein
    """

    # ===========================================================================
    # Parsing definitions
    # ===========================================================================
    #: Legal names for temporary variables
    TEMP_VAR = ~Literal("x_new") + Word(
        f"{string.ascii_letters}_", f"{string.ascii_letters + string.digits}_"
    ).setResultsName("identifier")

    #: A single expression
    EXPRESSION = restOfLine.setResultsName("expression")

    #: An assignment statement
    STATEMENT = Group(TEMP_VAR + Suppress("=") + EXPRESSION).setResultsName("statement")

    #: The last line of a state updater description
    OUTPUT = Group(
        Suppress(Literal("x_new")) + Suppress("=") + EXPRESSION
    ).setResultsName("output")

    #: A complete state updater description
    DESCRIPTION = ZeroOrMore(STATEMENT) + OUTPUT

    def __init__(self, description, stochastic=None, custom_check=None):
        self._description = description
        self.stochastic = stochastic
        self.custom_check = custom_check

        try:
            parsed = ExplicitStateUpdater.DESCRIPTION.parseString(
                description, parseAll=True
            )
        except ParseException as p_exc:
            # Report parsing problems like a Python syntax error, pointing at
            # the offending line/column of the description.
            ex = SyntaxError(f"Parsing failed: {str(p_exc.msg)}")
            ex.text = str(p_exc.line)
            ex.offset = p_exc.column
            ex.lineno = p_exc.lineno
            raise ex from p_exc

        self.statements = []
        self.symbols = SYMBOLS.copy()
        for element in parsed:
            expression = str_to_sympy(element.expression)
            # Replace all symbols used in state updater expressions by unique
            # names that cannot clash with user-defined variables or functions
            expression = expression.subs(sympy.Function("f"), self.symbols["__f"])
            expression = expression.subs(sympy.Function("g"), self.symbols["__g"])
            symbols = list(expression.atoms(sympy.Symbol))
            unique_symbols = []
            for symbol in symbols:
                if symbol.name == "dt":
                    unique_symbols.append(symbol)
                else:
                    unique_symbols.append(_symbol(f"__{symbol.name}"))
            for symbol, unique_symbol in zip(symbols, unique_symbols):
                expression = expression.subs(symbol, unique_symbol)

            self.symbols.update({symbol.name: symbol for symbol in unique_symbols})
            if element.getName() == "statement":
                temp_var = f"__{element.identifier}"
                # `__call__` looks up every intermediate variable in
                # `self.symbols`, so register it even if no later line of the
                # description refers to it (otherwise that lookup fails).
                self.symbols.setdefault(temp_var, _symbol(temp_var))
                self.statements.append((temp_var, expression))
            elif element.getName() == "output":
                self.output = expression
            else:
                raise AssertionError(f"Unknown element name: {element.getName()}")

    def __repr__(self):
        # recreate a description string
        description = "\n".join([f"{var} = {expr}" for var, expr in self.statements])
        if len(description):
            description += "\n"
        description += f"x_new = {str(self.output)}"
        classname = self.__class__.__name__
        return f"{classname}('''{description}''', stochastic={self.stochastic!r})"

    def __str__(self):
        s = f"{self.__class__.__name__}\n"

        if len(self.statements) > 0:
            s += "Intermediate statements:\n"
            s += "\n".join(
                [f"{var} = {sympy_to_str(expr)}" for var, expr in self.statements]
            )
            s += "\n"

        s += "Output:\n"
        s += sympy_to_str(self.output)
        return s

    def _latex(self, *args):
        from sympy import Symbol, latex

        # The expressions store the state variable as the real-valued symbol
        # ``__x`` (see `__init__`); a plain ``Symbol("x")`` would never match
        # it, so substitute that exact symbol.
        x_symbol = self.symbols["__x"]
        x_t = Symbol("x_t")
        s = [r"\begin{equation}"]
        for var, expr in self.statements:
            expr = expr.subs(x_symbol, x_t)
            s.append(f"{latex(Symbol(var))} = {latex(expr)}\\\\")
        expr = self.output.subs(x_symbol, x_t)
        s.append(f"x_{{t+1}} = {latex(expr)}")
        s.append(r"\end{equation}")
        return "\n".join(s)

    def _repr_latex_(self):
        return self._latex()

    def replace_func(self, x, t, expr, temp_vars, eq_symbols, stochastic_variable=None):
        """
        Used to replace a single occurrence of ``f(x, t)`` or ``g(x, t)``:
        `expr` is the non-stochastic (in the case of ``f``) or stochastic
        part (``g``) of the expression defining the right-hand-side of the
        differential equation describing `var`. It replaces the variable
        `var` with the value given as `x` and `t` by the value given for
        `t`. Intermediate variables will be replaced with the appropriate
        replacements as well.

        For example, in the `rk2` integrator, the second step involves the
        calculation of ``f(k/2 + x, dt/2 + t)``.  If `var` is ``v`` and
        `expr` is ``-v / tau``, this will result in ``-(_k_v/2 + v)/tau``.

        Note that this deals with only one state variable `var`, given as
        an argument to the surrounding `_generate_RHS` function.

        Raises
        ------
        ValueError
            If `expr` cannot be converted into a sympy expression.
        """

        try:
            s_expr = str_to_sympy(str(expr))
        except SympifyError as ex:
            raise ValueError(f'Error parsing the expression "{expr}": {str(ex)}') from ex

        for var in eq_symbols:
            # Generate specific temporary variables for the state variable,
            # e.g. '_k_v' for the state variable 'v' and the temporary
            # variable 'k'.
            if stochastic_variable is None:
                temp_var_replacements = {
                    self.symbols[temp_var]: _symbol(f"{temp_var}_{var}")
                    for temp_var in temp_vars
                }
            else:
                temp_var_replacements = {
                    self.symbols[temp_var]: _symbol(
                        f"{temp_var}_{var}_{stochastic_variable}"
                    )
                    for temp_var in temp_vars
                }
            # In the expression given as 'x', replace 'x' by the variable
            # 'var' and all the temporary variables by their
            # variable-specific counterparts.
            x_replacement = x.subs(self.symbols["__x"], eq_symbols[var])
            x_replacement = x_replacement.subs(temp_var_replacements)

            # Replace the variable `var` in the expression by the new `x`
            # expression
            s_expr = s_expr.subs(eq_symbols[var], x_replacement)

        # If the expression given for t in the state updater description
        # is not just "t" (or rather "__t"), then replace t in the
        # equations by it, and replace "__t" by "t" afterwards.
        if t != self.symbols["__t"]:
            s_expr = s_expr.subs(SYMBOLS["t"], t)
            s_expr = s_expr.replace(self.symbols["__t"], SYMBOLS["t"])

        return s_expr

    def _non_stochastic_part(
        self,
        eq_symbols,
        non_stochastic,
        non_stochastic_expr,
        stochastic_variable,
        temp_vars,
        var,
    ):
        """
        Substitute the model's non-stochastic right-hand side into the ``f``
        part of one line of the state updater description.

        `stochastic_variable` selects how intermediate variables are named:
        ``None``/empty gives e.g. ``__k_v``, a single name ``xi`` gives
        ``__k_v_xi``, and a collection of names replaces each intermediate
        variable by the sum of its per-noise-variable versions.

        Returns
        -------
        list
            A list containing exactly one sympy expression.
        """
        if stochastic_variable is None or len(stochastic_variable) == 0:
            # No noise variable: temporaries are named e.g. "__k_v"
            func_stochastic_variable = None
            temp_var_replacements = {
                self.symbols[temp_var]: _symbol(f"{temp_var}_{var}")
                for temp_var in temp_vars
            }
        elif isinstance(stochastic_variable, str):
            # Evaluated for a single noise variable: temporaries are named
            # e.g. "__k_v_xi", and f's substitution uses the same naming
            func_stochastic_variable = stochastic_variable
            temp_var_replacements = {
                self.symbols[temp_var]: _symbol(
                    f"{temp_var}_{var}_{stochastic_variable}"
                )
                for temp_var in temp_vars
            }
        else:
            # Several noise variables (the final "x_new" line): every
            # temporary is the sum of its per-noise-variable versions
            func_stochastic_variable = None
            temp_var_replacements = {
                self.symbols[temp_var]: reduce(
                    operator.add,
                    [_symbol(f"{temp_var}_{var}_{xi}") for xi in stochastic_variable],
                )
                for temp_var in temp_vars
            }

        # Replace the f(x, t) part
        def replace_f(x, t):
            return self.replace_func(
                x, t, non_stochastic, temp_vars, eq_symbols, func_stochastic_variable
            )

        non_stochastic_result = non_stochastic_expr.replace(
            self.symbols["__f"], replace_f
        )
        # Replace x by the respective variable
        non_stochastic_result = non_stochastic_result.subs(
            self.symbols["__x"], eq_symbols[var]
        )
        # Replace intermediate variables
        non_stochastic_result = non_stochastic_result.subs(temp_var_replacements)
        return [non_stochastic_result]

    def _stochastic_part(
        self,
        eq_symbols,
        stochastic,
        stochastic_expr,
        stochastic_variable,
        temp_vars,
        var,
    ):
        """
        Substitute the model's stochastic right-hand side into the ``g``
        part of one line of the state updater description, once per noise
        variable.

        `stochastic_variable` is either a single noise variable name or a
        collection of names; `stochastic` maps noise variable names to the
        corresponding stochastic part of the model equation (missing names
        contribute ``0``).

        Returns
        -------
        list
            One sympy expression per noise variable.
        """
        if isinstance(stochastic_variable, str):
            noise_variables = [stochastic_variable]
        else:
            noise_variables = stochastic_variable

        stochastic_results = []
        for xi in noise_variables:
            # Replace the g(x, t) part. The callback is used immediately by
            # `replace` within this iteration, so late binding of `xi` is
            # harmless.
            replace_g = lambda x, t: self.replace_func(
                x, t, stochastic.get(xi, 0), temp_vars, eq_symbols, xi  # noqa: B023
            )
            stochastic_result = stochastic_expr.replace(self.symbols["__g"], replace_g)
            # Replace x by the respective variable
            stochastic_result = stochastic_result.subs(
                self.symbols["__x"], eq_symbols[var]
            )
            # Replace dW by the respective variable
            stochastic_result = stochastic_result.subs(self.symbols["__dW"], xi)
            # Replace intermediate variables
            temp_var_replacements = {
                self.symbols[temp_var]: _symbol(f"{temp_var}_{var}_{xi}")
                for temp_var in temp_vars
            }
            stochastic_result = stochastic_result.subs(temp_var_replacements)
            stochastic_results.append(stochastic_result)
        return stochastic_results

    def _generate_RHS(
        self,
        eqs,
        var,
        eq_symbols,
        temp_vars,
        expr,
        non_stochastic_expr,
        stochastic_expr,
        stochastic_variable=(),
    ):
        """
        Helper function used in `__call__`. Generates the right hand side of
        an abstract code statement by appropriately replacing f, g and t.
        For example, given a differential equation ``dv/dt = -(v + I) / tau``
        (i.e. `var` is ``v`` and `expr` is ``(-v + I) / tau``) together with
        the `rk2` step ``return x + dt*f(x +  k/2, t + dt/2)``
        (i.e. `non_stochastic_expr` is
        ``x + dt*f(x +  k/2, t + dt/2)`` and `stochastic_expr` is ``None``),
        produces ``v + dt*(-v - _k_v/2 + I + _k_I/2)/tau``.

        The `eqs` argument is not used here.
        """

        # Note: in the following we are silently ignoring the case that a
        # state updater does not care about either the non-stochastic or the
        # stochastic part of an equation. We do trust state updaters to
        # correctly specify their own abilities (i.e. they do not claim to
        # support stochastic equations but actually just ignore the stochastic
        # part). We can't really check the issue here, as we are only dealing
        # with one line of the state updater description. It is perfectly valid
        # to write the euler update as:
        #     non_stochastic = dt * f(x, t)
        #     stochastic = dt**.5 * g(x, t) * xi
        #     return x + non_stochastic + stochastic
        #
        # In the above case, we'll deal with lines which do not define either
        # the stochastic or the non-stochastic part.

        non_stochastic, stochastic = expr.split_stochastic()

        if non_stochastic_expr is not None:
            # We do have a non-stochastic part in the state updater description
            non_stochastic_results = self._non_stochastic_part(
                eq_symbols,
                non_stochastic,
                non_stochastic_expr,
                stochastic_variable,
                temp_vars,
                var,
            )
        else:
            non_stochastic_results = []

        if not (stochastic is None or stochastic_expr is None):
            # We do have a stochastic part in the state
            # updater description
            stochastic_results = self._stochastic_part(
                eq_symbols,
                stochastic,
                stochastic_expr,
                stochastic_variable,
                temp_vars,
                var,
            )
        else:
            stochastic_results = []

        RHS = sympy.Number(0)
        # All the parts (one non-stochastic and potentially more than one
        # stochastic part) are combined with addition
        for non_stochastic_result in non_stochastic_results:
            RHS += non_stochastic_result
        for stochastic_result in stochastic_results:
            RHS += stochastic_result

        return sympy_to_str(RHS)

    def __call__(self, eqs, variables=None, method_options=None):
        """
        Apply a state updater description to model equations.

        Parameters
        ----------
        eqs : `Equations`
            The equations describing the model
        variables: dict-like, optional
            The `Variable` objects for the model. Passed on to the equations'
            substitution, to ``custom_check``, and used to check whether
            apparently multiplicative noise is constant over a time step.
        method_options : dict, optional
            Additional options to the state updater (not used at the moment
            for the explicit state updaters).

        Raises
        ------
        UnsupportedEquationsException
            If the equations are stochastic and this state updater does not
            support them (or does not support their type of noise).

        Examples
        --------
        >>> from brian2 import *
        >>> eqs = Equations('dv/dt = -v / tau : volt')
        >>> print(euler(eqs))
        _v = -dt*v/tau + v
        v = _v
        >>> print(rk4(eqs))
        __k_1_v = -dt*v/tau
        __k_2_v = -dt*(__k_1_v/2 + v)/tau
        __k_3_v = -dt*(__k_2_v/2 + v)/tau
        __k_4_v = -dt*(__k_3_v + v)/tau
        _v = __k_1_v/6 + __k_2_v/3 + __k_3_v/3 + __k_4_v/6 + v
        v = _v
        """
        extract_method_options(method_options, {})
        # Non-stochastic numerical integrators should work for all equations,
        # except for stochastic equations
        if eqs.is_stochastic and self.stochastic is None:
            raise UnsupportedEquationsException(
                "Cannot integrate stochastic equations with this state updater."
            )
        if self.custom_check:
            self.custom_check(eqs, variables)
        # The final list of statements
        statements = []

        stochastic_variables = eqs.stochastic_variables

        # The variables for the intermediate results in the state updater
        # description, e.g. the variable k in rk2
        intermediate_vars = [var for var, expr in self.statements]

        # A dictionary mapping all the variables in the equations to their
        # sympy representations
        eq_variables = {var: _symbol(var) for var in eqs.eq_names}

        # Generate the random numbers for the stochastic variables
        for stochastic_variable in stochastic_variables:
            statements.append(f"{stochastic_variable} = dt**.5 * randn()")

        substituted_expressions = eqs.get_substituted_expressions(variables)

        # Process the intermediate statements in the stateupdater description
        for intermediate_var, intermediate_expr in self.statements:
            # Split the expression into a non-stochastic and a stochastic part
            non_stochastic_expr, stochastic_expr = split_expression(intermediate_expr)

            # Execute the statement by appropriately replacing the functions f
            # and g and the variable x for every equation in the model.
            # We use the model equations where the subexpressions have
            # already been substituted into the model equations.
            for var, expr in substituted_expressions:
                for xi in stochastic_variables:
                    RHS = self._generate_RHS(
                        eqs,
                        var,
                        eq_variables,
                        intermediate_vars,
                        expr,
                        non_stochastic_expr,
                        stochastic_expr,
                        xi,
                    )
                    statements.append(f"{intermediate_var}_{var}_{xi} = {RHS}")
                if not stochastic_variables:  # no stochastic variables
                    RHS = self._generate_RHS(
                        eqs,
                        var,
                        eq_variables,
                        intermediate_vars,
                        expr,
                        non_stochastic_expr,
                        stochastic_expr,
                    )
                    statements.append(f"{intermediate_var}_{var} = {RHS}")

        # Process the "return" line of the stateupdater description
        non_stochastic_expr, stochastic_expr = split_expression(self.output)

        if eqs.is_stochastic and (
            self.stochastic != "multiplicative"
            and eqs.stochastic_type == "multiplicative"
        ):
            # The equations are marked as having multiplicative noise and the
            # current state updater does not support such equations. However,
            # it is possible that the equations do not use multiplicative noise
            # at all. They could depend on time via a function that is constant
            # over a single time step (most likely, a TimedArray). In that case
            # we can integrate the equations.
            # `variables` is optional: treat a missing value as an empty
            # mapping instead of failing on ``"dt" in None``.
            # NOTE(review): assumes is_constant_over_dt accepts an empty mapping.
            known_variables = variables if variables is not None else {}
            dt_value = (
                known_variables["dt"].get_value()[0]
                if "dt" in known_variables
                else None
            )
            for _, expr in substituted_expressions:
                _, stoch = expr.split_stochastic()
                if stoch is None:
                    continue
                # There could be more than one stochastic variable (e.g. xi_1, xi_2)
                for _, stoch_expr in stoch.items():
                    sympy_expr = str_to_sympy(stoch_expr.code)
                    # The equation really has multiplicative noise, if it depends
                    # on time (and not only via a function that is constant
                    # over dt), or if it depends on another variable defined
                    # via differential equations.
                    if not is_constant_over_dt(
                        sympy_expr, known_variables, dt_value
                    ) or len(stoch_expr.identifiers & eqs.diff_eq_names):
                        raise UnsupportedEquationsException(
                            "Cannot integrate "
                            "equations with "
                            "multiplicative noise with "
                            "this state updater."
                        )

        # Assign a value to all the model variables described by differential
        # equations
        for var, expr in substituted_expressions:
            RHS = self._generate_RHS(
                eqs,
                var,
                eq_variables,
                intermediate_vars,
                expr,
                non_stochastic_expr,
                stochastic_expr,
                stochastic_variables,
            )
            statements.append(f"_{var} = {RHS}")

        # Assign everything to the final variables
        for var, _ in substituted_expressions:
            statements.append(f"{var} = _{var}")

        return "\n".join(statements)


# ===============================================================================
# Explicit state updaters
# ===============================================================================

# These objects can be used like functions because they are callable
# (see ExplicitStateUpdater.__call__).

#: Forward Euler state updater
euler = ExplicitStateUpdater(
    description="x_new = x + dt * f(x,t) + g(x,t) * dW",
    stochastic="additive",
)

#: Second order Runge-Kutta method (midpoint method)
rk2 = ExplicitStateUpdater(
    description="""
    k = dt * f(x,t)
    x_new = x + dt*f(x +  k/2, t + dt/2)"""
)

#: Classical Runge-Kutta method (RK4)
rk4 = ExplicitStateUpdater(
    description="""
    k_1 = dt*f(x,t)
    k_2 = dt*f(x+k_1/2,t+dt/2)
    k_3 = dt*f(x+k_2/2,t+dt/2)
    k_4 = dt*f(x+k_3,t+dt)
    x_new = x+(k_1+2*k_2+2*k_3+k_4)/6
    """
)


def diagonal_noise(equations, variables):
    """
    Verify that stochastic equations use diagonal noise.

    Diagonal noise means that every equation depends on at most one noise
    variable and that no noise variable is shared between equations.
    Non-stochastic equations are accepted without further checks.

    Parameters
    ----------
    equations : `Equations`
        The model equations.
    variables : dict-like
        Passed on to ``equations.get_substituted_expressions``.

    Raises
    ------
    UnsupportedEquationsException
        If the noise is not diagonal.
    """
    if not equations.is_stochastic:
        return

    message = (
        "Cannot integrate stochastic equations with non-diagonal noise with "
        "this state updater."
    )
    used_noise = []
    for _, expression in equations.get_substituted_expressions(variables):
        noise_vars = expression.stochastic_variables
        # An equation driven by several noise variables is not diagonal
        if len(noise_vars) > 1:
            raise UnsupportedEquationsException(message)
        used_noise += noise_vars

    # The noise is diagonal only if no noise variable appears in more than
    # one equation
    if len(set(used_noise)) < len(used_noise):
        raise UnsupportedEquationsException(message)


#: Derivative-free Milstein method
milstein = ExplicitStateUpdater(
    description="""
    x_support = x + dt*f(x, t) + dt**.5 * g(x, t)
    g_support = g(x_support, t)
    k = 1/(2*dt**.5)*(g_support - g(x, t))*(dW**2)
    x_new = x + dt*f(x,t) + g(x, t) * dW + k
    """,
    stochastic="multiplicative",
    # only equations with diagonal noise are accepted
    custom_check=diagonal_noise,
)

#: Stochastic Heun method (for multiplicative Stratonovich SDEs with non-diagonal
#: diffusion matrix)
heun = ExplicitStateUpdater(
    description="""
    x_support = x + g(x,t) * dW
    g_support = g(x_support,t+dt)
    x_new = x + dt*f(x,t) + .5*dW*(g(x,t)+g_support)
    """,
    stochastic="multiplicative",
)
