File: test_visitor.py

package info (click to toggle)
nmodl 0.6-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,992 kB
  • sloc: cpp: 28,492; javascript: 9,841; yacc: 2,804; python: 1,967; lex: 1,674; xml: 181; sh: 136; ansic: 37; makefile: 18; pascal: 7
file content (113 lines) | stat: -rw-r--r-- 3,751 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# ***********************************************************************
# Copyright (C) 2018-2022 Blue Brain Project
#
# This file is part of NMODL distributed under the terms of the GNU
# Lesser General Public License. See top-level LICENSE file for details.
# ***********************************************************************

import nmodl
from nmodl.dsl import ast, visitor
import pytest


def test_lookup_visitor(ch_ast):
    """Find DIFF_EQ_EXPRESSION nodes by passing the node type to ``lookup``."""
    finder = visitor.AstLookupVisitor()
    diff_eqs = finder.lookup(ch_ast, ast.AstNodeType.DIFF_EQ_EXPRESSION)
    assert nmodl.dsl.to_nmodl(diff_eqs[0]) == "m' = mInf-m"


def test_lookup_visitor_any_node():
    """Check that AstLookupVisitor.lookup accepts an arbitrary node, not just a program."""
    int_finder = visitor.AstLookupVisitor(ast.AstNodeType.INTEGER)
    node = ast.Integer(42, None)

    # Type supplied at construction: the Integer node itself is the single match.
    found = int_finder.lookup(node)
    assert len(found) == 1

    # Type supplied to lookup(): searching an Integer for DOUBLE nodes finds nothing.
    found = int_finder.lookup(node, ast.AstNodeType.DOUBLE)
    assert len(found) == 0


def test_lookup_visitor_constructor(ch_ast):
    """Find DIFF_EQ_EXPRESSION nodes using the node type given to the constructor."""
    lookup_visitor = visitor.AstLookupVisitor(ast.AstNodeType.DIFF_EQ_EXPRESSION)
    eqs = lookup_visitor.lookup(ch_ast)
    eq_str = nmodl.dsl.to_nmodl(eqs[0])
    # Without this assertion the test only checked that lookup() did not raise.
    # A constructor-supplied type must yield the same first equation as passing
    # the type to lookup() directly.
    assert eq_str == "m' = mInf-m"


def test_json_visitor(ch_ast):
    """Check NMODL and JSON serialisation of the first PrimeName node in ``ch_ast``."""
    lookup_visitor = visitor.AstLookupVisitor(ast.AstNodeType.PRIME_NAME)
    primes = lookup_visitor.lookup(ch_ast)

    # The NMODL text of the prime was previously computed but never checked;
    # "m'" matches the "nmodl" field expected for PrimeName further below.
    prime_str = nmodl.dsl.to_nmodl(primes[0])
    assert prime_str == "m'"

    # test compact json
    prime_json = nmodl.dsl.to_json(primes[0], True)
    assert prime_json == '{"PrimeName":[{"String":[{"name":"m"}]},{"Integer":[{"name":"1"}]}]}'

    # test json with expanded keys
    result_json = nmodl.dsl.to_json(primes[0], compact=True, expand=True)
    expected_json = ('{"children":[{"children":[{"name":"m"}],'
                     '"name":"String"},{"children":[{"name":"1"}],'
                     '"name":"Integer"}],"name":"PrimeName"}')
    assert result_json == expected_json

    # test json with nmodl embedded
    result_json = nmodl.dsl.to_json(primes[0], compact=True, expand=True, add_nmodl=True)
    expected_json = ('{"children":[{"children":[{"name":"m"}],"name":"String","nmodl":"m"},'
                     '{"children":[{"name":"1"}],"name":"Integer","nmodl":"1"}],'
                     '"name":"PrimeName","nmodl":"m\'"}')
    assert result_json == expected_json


def test_custom_visitor(ch_ast):
    """Collect the names visited inside the STATE block with a user-defined visitor."""

    class StateVisitor(visitor.AstVisitor):
        """Record the NMODL text of every Name node visited inside a STATE block."""

        def __init__(self):
            visitor.AstVisitor.__init__(self)
            self.in_state = False  # True only while a STATE block's children are visited
            self.states = []  # NMODL text of each Name seen inside STATE, in visit order

        def visit_state_block(self, node):
            self.in_state = True
            node.visit_children(self)
            self.in_state = False

        def visit_name(self, node):
            if self.in_state:
                self.states.append(nmodl.dsl.to_nmodl(node))

    myvisitor = StateVisitor()
    ch_ast.accept(myvisitor)

    # Compare by value: `len(...) is 2` relied on CPython's small-int caching
    # and raises a SyntaxWarning ("is" with a literal) on Python 3.8+.
    assert len(myvisitor.states) == 2
    assert myvisitor.states[0] == "m"
    assert myvisitor.states[1] == "h"


def test_modify_ast():
    """Rename a RANGE variable in place with a visitor and check the printed program."""
    source = """NEURON {
    SUFFIX test
    RANGE x
}
    """

    class RenameRangeVisitor(visitor.AstVisitor):
        """Replace the name of each RANGE variable whose text equals ``old_name``."""

        def __init__(self, old_name, new_name):
            visitor.AstVisitor.__init__(self)
            self.old_name = old_name
            self.new_name = new_name

        def visit_range_var(self, node):
            is_target = nmodl.to_nmodl(node.name) == self.old_name
            if is_target:
                node.name.value = ast.String(self.new_name)
            node.visit_children(self)

    parser = nmodl.NmodlDriver()
    program = parser.parse_string(source)
    renamer = RenameRangeVisitor("x", "y")
    renamer.visit_program(program)

    expected = """NEURON {
    SUFFIX test
    RANGE y
}
"""
    assert str(program) == expected