File: test_ode.py

package info (click to toggle)
nmodl 0.6-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,992 kB
  • sloc: cpp: 28,492; javascript: 9,841; yacc: 2,804; python: 1,967; lex: 1,674; xml: 181; sh: 136; ansic: 37; makefile: 18; pascal: 7
file content (133 lines) | stat: -rw-r--r-- 4,849 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# ***********************************************************************
# Copyright (C) 2018-2022 Blue Brain Project
#
# This file is part of NMODL distributed under the terms of the GNU
# Lesser General Public License. See top-level LICENSE file for details.
# ***********************************************************************

from nmodl.ode import differentiate2c, integrate2c

import sympy as sp


def _equivalent(
    lhs, rhs, vars=["a", "b", "c", "d", "e", "f", "v", "w", "x", "y", "z", "t", "dt"]
):
    """Helper function to test equivalence of analytic expressions
    Analytic expressions can often be written in many different,
    but mathematically equivalent ways. This helper function uses
    SymPy to check if two analytic expressions are equivalent.
    If the expressions contain an "=", each is split into two expressions,
    and the two pairs of expressions are compared, i.e.
    _equivalent("a=b+c", "a=c+b") is the same thing as doing
    _equivalent("a", "a") and _equivalent("b+c", "c+b")
    Args:
        lhs: first expression, e.g. "x*(1-a)"
        rhs: second expression, e.g. "-a*x + x"
        vars: list of variables used in expressions, e.g. ["a", "x"]
    Returns:
        True if expressions are equivalent, False if they are not
    """
    lhs = lhs.replace("pow(", "Pow(")
    rhs = rhs.replace("pow(", "Pow(")
    sympy_vars = {var: sp.symbols(var, real=True) for var in vars}
    for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)):
        eq_l = sp.sympify(l, locals=sympy_vars)
        eq_r = sp.sympify(r, locals=sympy_vars)
        difference = (eq_l - eq_r).evalf().simplify()
        if difference != 0:
            return False
    return True


def test_differentiate2c():

    # simple examples, no prev_expressions
    assert _equivalent(differentiate2c("0", "x", ""), "0")
    assert _equivalent(differentiate2c("x", "x", ""), "1")
    assert _equivalent(differentiate2c("a", "x", "a"), "0")
    assert _equivalent(differentiate2c("a*x", "x", "a"), "a")
    assert _equivalent(differentiate2c("a*x", "a", "x"), "x")
    assert _equivalent(differentiate2c("a*x", "y", {"x", "y"}), "0")
    assert _equivalent(differentiate2c("a*x + b*x*x", "x", {"a", "b"}), "2*b*x+a")
    assert _equivalent(
        differentiate2c("a*cos(x+b)", "x", {"a", "b"}), "-a * sin(b + x)"
    )
    assert _equivalent(
        differentiate2c("a*cos(x+b) + c*x*x", "x", {"a", "b", "c"}),
        "-a*sin(b+x) + 2*c*x",
    )

    # single prev_expression to substitute
    assert _equivalent(
        differentiate2c("a*x + b", "x", {"a", "b", "c", "d"}, ["c = sqrt(d)"]), "a"
    )
    assert _equivalent(
        differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x"]), "a + 2"
    )

    # multiple prev_eqs to substitute
    # (these statements should be in the same order as in the mod file)
    assert _equivalent(
        differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x", "a = -2*x*x"]), "-6*x*x+2"
    )
    assert _equivalent(
        differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x*x", "a = -2*x"]), "0"
    )

    # multiple prev_eqs to recursively substitute
    # note prev_eqs always substituted in reverse order
    # and only x-dependent rhs's are substituted, e.g. 'a' remains 'a' here:
    assert _equivalent(
        differentiate2c("a*x + b", "x", {"a", "b"}, ["a=3", "b = 2*a*x"]), "3*a"
    )
    assert _equivalent(
        differentiate2c(
            "a*x + b*c", "x", {"a", "b", "c"}, ["a=3", "b = 2*a*x", "c = a/x"]
        ),
        "a",
    )
    assert _equivalent(
        differentiate2c("-a*x + b*c", "x", {"a", "b", "c"}, ["b = 2*x*x", "c = a/x"]),
        "a",
    )
    assert _equivalent(
        differentiate2c(
            "(g1 + g2)*(v-e)",
            "v",
            {"g", "e", "g1", "g2", "c", "d"},
            ["g2 = sqrt(d) + 3", "g1 = 2*c", "g = g1 + g2"],
        ),
        "g",
    )


def test_integrate2c():

    # list of variables used for integrate2c
    var_list = ["x", "a", "b"]
    # pairs of (f(x), g(x))
    # where f(x) is the differential equation: dx/dt = f(x)
    # and g(x) is the solution: x(t+dt) = g(x(t))
    test_cases = [
        ("0", "x"),
        ("a", "x + a*dt"),
        ("a*x", "x*exp(a*dt)"),
        ("a*x+b", "(-b + (a*x + b)*exp(a*dt))/a"),
    ]
    for (eq, sol) in test_cases:
        assert _equivalent(
            integrate2c(f"x'={eq}", "dt", var_list, use_pade_approx=False), f"x = {sol}"
        )

    # repeat with solutions replaced with (1,1) Pade approximant
    pade_test_cases = [
        ("0", "x"),
        ("a", "x + a*dt"),
        ("a*x", "-x*(a*dt+2)/(a*dt-2)"),
        ("a*x+b", "-(a*dt*x+2*b*dt+2*x)/(a*dt-2)"),
    ]
    for (eq, sol) in pade_test_cases:
        assert _equivalent(
            integrate2c(f"x'={eq}", "dt", var_list, use_pade_approx=True), f"x = {sol}"
        )