1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
/*************************************************************************
* Copyright (C) 2018-2022 Blue Brain Project
*
* This file is part of NMODL distributed under the terms of the GNU
* Lesser General Public License. See top-level LICENSE file for details.
*************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "ast/program.hpp"
#include "parser/nmodl_driver.hpp"
#include "test/unit/utils/test_utils.hpp"
#include "visitors/checkparent_visitor.hpp"
#include "visitors/neuron_solve_visitor.hpp"
#include "visitors/nmodl_visitor.hpp"
#include "visitors/solve_block_visitor.hpp"
#include "visitors/symtab_visitor.hpp"
using namespace nmodl;
using namespace visitor;
using namespace test;
using namespace test_utils;
using nmodl::parser::NmodlDriver;
//=============================================================================
// SolveBlock visitor tests
//=============================================================================
std::string run_solve_block_visitor(const std::string& text) {
NmodlDriver driver;
const auto& ast = driver.parse_string(text);
SymtabVisitor().visit_program(*ast);
NeuronSolveVisitor().visit_program(*ast);
SolveBlockVisitor().visit_program(*ast);
std::stringstream stream;
NmodlPrintVisitor(stream).visit_program(*ast);
// check that, after visitor rearrangement, parents are still up-to-date
CheckParentVisitor().check_ast(*ast);
return stream.str();
}
/// Checks that, after the legacy NeuronSolveVisitor has solved the DERIVATIVE
/// blocks, SolveBlockVisitor appends an NRN_STATE block that holds the solved
/// statements for the SOLVE statement(s) found in BREAKPOINT.
///
/// NOTE: the raw-string contents below are compared through reindent_text,
/// so their line layout is significant and must match the printer output.
TEST_CASE("Solve ODEs using legacy NeuronSolveVisitor", "[visitor][solver]") {
    SECTION("SolveBlock add NrnState block") {
        GIVEN("Breakpoint block with single solve block in breakpoint") {
            const std::string nmodl_text = R"(
BREAKPOINT {
SOLVE states METHOD cnexp
}
DERIVATIVE states {
m' = (mInf-m)/mTau
}
)";
            // The ODE is replaced by its cnexp solution, and the same solved
            // statement is copied into the generated NRN_STATE block.
            const std::string output_nmodl = R"(
BREAKPOINT {
SOLVE states METHOD cnexp
}
DERIVATIVE states {
m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
}
NRN_STATE SOLVE states METHOD cnexp{
m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
}
)";
            THEN("Single NrnState block gets added") {
                const auto result = run_solve_block_visitor(nmodl_text);
                REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
            }
        }
        GIVEN("Breakpoint block with two solve blocks in breakpoint") {
            const std::string nmodl_text = R"(
BREAKPOINT {
SOLVE state1 METHOD cnexp
SOLVE state2 METHOD cnexp
}
DERIVATIVE state1 {
m' = (mInf-m)/mTau
}
DERIVATIVE state2 {
h' = (mInf-h)/mTau
}
)";
            // Both SOLVE statements end up inside one NRN_STATE block, each
            // carrying the solved statement of its DERIVATIVE block.
            const std::string output_nmodl = R"(
BREAKPOINT {
SOLVE state1 METHOD cnexp
SOLVE state2 METHOD cnexp
}
DERIVATIVE state1 {
m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
}
DERIVATIVE state2 {
h = h+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-h)
}
NRN_STATE SOLVE state1 METHOD cnexp{
m = m+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-m)
}
SOLVE state2 METHOD cnexp{
h = h+(1.0-exp(dt*((((-1.0)))/mTau)))*(-(((mInf))/mTau)/((((-1.0)))/mTau)-h)
}
)";
            THEN("NrnState block combining multiple solve nodes added") {
                const auto result = run_solve_block_visitor(nmodl_text);
                REQUIRE(reindent_text(output_nmodl) == reindent_text(result));
            }
        }
    }
}
|