1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
/*************************************************************************
* Copyright (C) 2018-2022 Blue Brain Project
*
* This file is part of NMODL distributed under the terms of the GNU
* Lesser General Public License. See top-level LICENSE file for details.
*************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "ast/program.hpp"
#include "parser/nmodl_driver.hpp"
#include "test/unit/utils/test_utils.hpp"
#include "visitors/checkparent_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/loop_unroll_visitor.hpp"
#include "visitors/nmodl_visitor.hpp"
#include "visitors/symtab_visitor.hpp"
#include "visitors/visitor_utils.hpp"
using namespace nmodl;
using namespace visitor;
using namespace test;
using namespace test_utils;
using ast::AstNodeType;
using nmodl::parser::NmodlDriver;
//=============================================================================
// Loop unroll tests
//=============================================================================
std::string run_loop_unroll_visitor(const std::string& text) {
NmodlDriver driver;
const auto& ast = driver.parse_string(text);
SymtabVisitor().visit_program(*ast);
ConstantFolderVisitor().visit_program(*ast);
LoopUnrollVisitor().visit_program(*ast);
ConstantFolderVisitor().visit_program(*ast);
// check that, after visitor rearrangement, parents are still up-to-date
CheckParentVisitor().check_ast(*ast);
return to_nmodl(ast, {AstNodeType::DEFINE});
}
SCENARIO("Perform loop unrolling of FROM construct", "[visitor][unroll]") {
    // Each case pairs an NMODL input with the exact NMODL expected after
    // run_loop_unroll_visitor (which omits DEFINE statements from its output);
    // both sides are compared after indentation is normalised by reindent_text.
    GIVEN("A loop with known iteration space") {
        // N = 2 and FROM bounds are inclusive, so `FROM i=0 TO N` produces three
        // statements; the second loop's bound expressions reduce to 1..3 and the
        // KINETIC loop runs over i = 1..3.
        const std::string input_nmodl = R"(
DEFINE N 2
PROCEDURE rates() {
LOCAL x[N]
FROM i=0 TO N {
x[i] = x[i] + 11
}
FROM i=(0+(0+1)) TO (N+2-1) {
x[(i+0)] = x[i+1] + 11
}
}
KINETIC state {
FROM i=1 TO N+1 {
~ ca[i] <-> ca[i+1] (DFree*frat[i+1]*1(um), DFree*frat[i+1]*1(um))
}
}
)";
        // Each unrolled loop body is emitted inside its own `{ ... }` block.
        const std::string expected_nmodl = R"(
PROCEDURE rates() {
LOCAL x[N]
{
x[0] = x[0]+11
x[1] = x[1]+11
x[2] = x[2]+11
}
{
x[1] = x[2]+11
x[2] = x[3]+11
x[3] = x[4]+11
}
}
KINETIC state {
{
~ ca[1] <-> ca[2] (DFree*frat[2]*1(um), DFree*frat[2]*1(um))
~ ca[2] <-> ca[3] (DFree*frat[3]*1(um), DFree*frat[3]*1(um))
~ ca[3] <-> ca[4] (DFree*frat[4]*1(um), DFree*frat[4]*1(um))
}
}
)";
        THEN("Loop body gets correctly unrolled") {
            const auto unrolled = run_loop_unroll_visitor(input_nmodl);
            REQUIRE(reindent_text(unrolled) == reindent_text(expected_nmodl));
        }
    }

    GIVEN("A nested loop") {
        // Outer i = 0..1, inner j = 1..2: the inner loop is unrolled once per
        // outer iteration, giving nested `{ ... }` blocks.
        const std::string input_nmodl = R"(
DEFINE N 1
PROCEDURE rates() {
LOCAL x[N]
FROM i=0 TO N {
FROM j=1 TO N+1 {
x[i] = x[i+j] + 1
}
}
}
)";
        const std::string expected_nmodl = R"(
PROCEDURE rates() {
LOCAL x[N]
{
{
x[0] = x[1]+1
x[0] = x[2]+1
}
{
x[1] = x[2]+1
x[1] = x[3]+1
}
}
}
)";
        THEN("Loop get unrolled recursively") {
            const auto unrolled = run_loop_unroll_visitor(input_nmodl);
            REQUIRE(reindent_text(unrolled) == reindent_text(expected_nmodl));
        }
    }

    GIVEN("Loop with verbatim and unknown iteration space") {
        // The outer loop has constant bounds and is unrolled; the inner loop's
        // upper bound `k` is not a constant, so it is kept as a FROM loop and
        // copied once per outer iteration. The loop containing VERBATIM is left
        // as a FROM loop in the expected output even though its bounds are known.
        const std::string input_nmodl = R"(
DEFINE N 1
PROCEDURE rates() {
LOCAL x[N]
FROM i=((0+0)) TO (((N+0))) {
FROM j=1 TO k {
x[i] = x[i+k] + 1
}
}
FROM i=0 TO N {
VERBATIM ENDVERBATIM
}
}
)";
        const std::string expected_nmodl = R"(
PROCEDURE rates() {
LOCAL x[N]
{
FROM j = 1 TO k {
x[0] = x[0+k]+1
}
FROM j = 1 TO k {
x[1] = x[1+k]+1
}
}
FROM i = 0 TO N {
VERBATIM ENDVERBATIM
}
}
)";
        THEN("Only some loops get unrolled") {
            const auto unrolled = run_loop_unroll_visitor(input_nmodl);
            REQUIRE(reindent_text(unrolled) == reindent_text(expected_nmodl));
        }
    }
}
|