File: solve_block.cpp

package info (click to toggle)
nmodl 0.6-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,016 kB
  • sloc: cpp: 28,492; javascript: 9,841; yacc: 2,804; python: 1,971; lex: 1,674; xml: 181; sh: 136; ansic: 37; makefile: 17; pascal: 7
file content (125 lines) | stat: -rw-r--r-- 4,026 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/*************************************************************************
 * Copyright (C) 2018-2022 Blue Brain Project
 *
 * This file is part of NMODL distributed under the terms of the GNU
 * Lesser General Public License. See top-level LICENSE file for details.
 *************************************************************************/

#include <sstream>
#include <string>

#include <catch2/catch_test_macros.hpp>

#include "ast/program.hpp"
#include "parser/nmodl_driver.hpp"
#include "test/unit/utils/test_utils.hpp"
#include "visitors/checkparent_visitor.hpp"
#include "visitors/neuron_solve_visitor.hpp"
#include "visitors/nmodl_visitor.hpp"
#include "visitors/solve_block_visitor.hpp"
#include "visitors/symtab_visitor.hpp"

using namespace nmodl;
using namespace visitor;
using namespace test;
using namespace test_utils;

using nmodl::parser::NmodlDriver;


//=============================================================================
// SolveBlock visitor tests
//=============================================================================

/**
 * Parse an NMODL snippet, run the solve-block pipeline on it and print it back.
 *
 * Visitors applied, in order: SymtabVisitor, NeuronSolveVisitor (the legacy
 * ODE solver), SolveBlockVisitor. The transformed AST is then printed with
 * NmodlPrintVisitor and checked with CheckParentVisitor so that any broken
 * parent links introduced by the rearrangement are reported.
 *
 * @param text NMODL source code to process
 * @return NMODL text produced by NmodlPrintVisitor for the transformed AST
 */
std::string run_solve_block_visitor(const std::string& text) {
    NmodlDriver driver;
    const auto& ast = driver.parse_string(text);

    SymtabVisitor().visit_program(*ast);
    NeuronSolveVisitor().visit_program(*ast);
    SolveBlockVisitor().visit_program(*ast);

    // The stream is only written to, so an output string stream suffices.
    std::ostringstream stream;
    NmodlPrintVisitor(stream).visit_program(*ast);

    // check that, after visitor rearrangement, parents are still up-to-date
    CheckParentVisitor().check_ast(*ast);

    return stream.str();
}

/// Checks that, after the legacy NeuronSolveVisitor has solved the cnexp ODEs,
/// SolveBlockVisitor appends an NRN_STATE block containing the solved
/// equations of every SOLVE statement listed in the BREAKPOINT block.
TEST_CASE("Solve ODEs using legacy NeuronSolveVisitor", "[visitor][solver]") {
    SECTION("SolveBlock add NrnState block") {
        GIVEN("Breakpoint block with single solve block in breakpoint") {
            std::string nmodl_text = R"(
            BREAKPOINT {
                SOLVE states METHOD cnexp
            }

            DERIVATIVE states {
                m' = (mInf-m)/mTau
            }
        )";

            std::string output_nmodl = R"(
            BREAKPOINT {
                SOLVE states METHOD cnexp
            }

            DERIVATIVE states {
                m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
            }

            NRN_STATE SOLVE states METHOD cnexp{
                m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
            }

        )";

            THEN("Single NrnState block gets added") {
                auto result = run_solve_block_visitor(nmodl_text);
                REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
            }
        }

        GIVEN("Breakpoint block with two solve blocks in breakpoint") {
            std::string nmodl_text = R"(
            BREAKPOINT {
                SOLVE state1 METHOD cnexp
                SOLVE state2 METHOD cnexp
            }

            DERIVATIVE state1 {
                m' = (mInf-m)/mTau
            }

            DERIVATIVE state2 {
                h' = (mInf-h)/mTau
            }
        )";

            // Both SOLVE statements end up in a single NRN_STATE block, each
            // followed by the solved equations of its DERIVATIVE block.
            std::string output_nmodl = R"(
            BREAKPOINT {
                SOLVE state1 METHOD cnexp
                SOLVE state2 METHOD cnexp
            }

            DERIVATIVE state1 {
                m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
            }

            DERIVATIVE state2 {
                h = h+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-h)
            }

            NRN_STATE SOLVE state1 METHOD cnexp{
                m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
            }
            SOLVE state2 METHOD cnexp{
                h = h+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-h)
            }

        )";

            THEN("NrnState block combining multiple solve nodes added") {
                auto result = run_solve_block_visitor(nmodl_text);
                REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
            }
        }
    }
}