File: base.py

package info (click to toggle)
brian 2.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,872 kB
  • sloc: python: 51,820; cpp: 2,033; makefile: 108; sh: 72
file content (307 lines) | stat: -rw-r--r-- 11,761 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
"""
Base class for generating code in different programming languages; defines the
methods that should be overridden to implement a new language.
"""

from brian2.codegen.permutation_analysis import (
    OrderDependenceError,
    check_for_order_independence,
)
from brian2.codegen.translation import make_statements
from brian2.core.functions import Function
from brian2.core.variables import ArrayVariable
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers

# Public API of this module.
__all__ = ["CodeGenerator"]


# Module-level logger; `CodeGenerator.translate` uses it to warn about abstract
# code whose outcome may depend on the order of execution.
logger = get_logger(__name__)


class CodeGenerator:
    """
    Base class for all languages.

    A code generator translates abstract code (see `translate`) into code for a
    target language. Language-specific subclasses override
    `translate_expression`, `translate_statement` and
    `translate_one_statement_sequence`, and may override `determine_keywords`
    to make additional values available to the template. The base class
    provides helpers that analyse which array variables and index arrays a
    sequence of statements reads from and writes to.
    """

    # Subclasses should override this
    class_name = ""

    def __init__(
        self,
        variables,
        variable_indices,
        owner,
        iterate_all,
        codeobj_class,
        name,
        template_name,
        override_conditional_write=None,
        allows_scalar_write=False,
    ):
        """
        Parameters
        ----------
        variables : dict
            Mapping from names to the variable/function objects used in the
            code.
        variable_indices : dict
            Mapping from variable names to the name of the index used to access
            them (the index names ``"_idx"`` and ``"0"`` are treated specially
            by the helper methods).
        owner
            Stored as ``self.owner``.
        iterate_all
            Stored as ``self.iterate_all``.
        codeobj_class
            The code object class. It is used to look up the implementation
            (and its name in the generated code) of `Function` objects.
        name : str
            Stored as ``self.name``.
        template_name : str
            Stored as ``self.template_name``.
        override_conditional_write : iterable of str, optional
            Names of condition variables whose conditional writes should be
            ignored by `get_conditional_write_vars`.
        allows_scalar_write : bool, optional
            Whether statements other than ``:=`` definitions may write to
            scalar variables (checked in `array_read_write`).
        """
        # We have to do the import here to avoid circular import dependencies.
        from brian2.devices.device import get_device

        self.device = get_device()
        self.variables = variables
        self.variable_indices = variable_indices
        # Functions whose implementation for this code object class uses a
        # different name in the generated code (e.g. "_rand" for "rand")
        self.func_name_replacements = {}
        for varname, var in variables.items():
            if isinstance(var, Function):
                if codeobj_class in var.implementations:
                    impl_name = var.implementations[codeobj_class].name
                    if impl_name is not None:
                        self.func_name_replacements[varname] = impl_name
        self.iterate_all = iterate_all
        self.codeobj_class = codeobj_class
        self.owner = owner
        if override_conditional_write is None:
            self.override_conditional_write = set()
        else:
            self.override_conditional_write = set(override_conditional_write)
        self.allows_scalar_write = allows_scalar_write
        self.name = name
        self.template_name = template_name
        # Gather the names of functions that should get an additional
        # "_vectorisation_idx" argument in the generated code. Take care
        # of storing their translated name (e.g. "_rand" instead of "rand")
        # if necessary
        self.auto_vectorise = {
            self.func_name_replacements.get(varname, varname)
            for varname in self.variables
            if getattr(self.variables[varname], "auto_vectorise", False)
        }

    @staticmethod
    def get_array_name(var, access_data=True):
        """
        Get a globally unique name for a `ArrayVariable`.

        Parameters
        ----------
        var : `ArrayVariable`
            The variable for which a name should be found.
        access_data : bool, optional
            For `DynamicArrayVariable` objects, specifying `True` here means the
            name for the underlying data is returned. If specifying `False`,
            the name of the object itself is returned (e.g. to allow resizing).

        Returns
        -------
        name : str
            A unique name for `var` (as provided by the current device).
        """
        # We have to do the import here to avoid circular import dependencies.
        from brian2.devices.device import get_device

        device = get_device()
        return device.get_array_name(var, access_data=access_data)

    def translate_expression(self, expr):
        """
        Translate the given expression string into a string in the target
        language, returns a string.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def translate_statement(self, statement):
        """
        Translate a single line `Statement` into the target language, returns
        a string.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def determine_keywords(self):
        """
        A dictionary of values that is made available to the template. This is
        used for example by the `CPPCodeGenerator` to set up all the supporting
        code. The base implementation returns an empty dictionary.
        """
        return {}

    def translate_one_statement_sequence(self, statements, scalar=False):
        """
        Translate the `Statement` objects of a single code block into the
        target language.

        Must be overridden by subclasses.

        Parameters
        ----------
        statements : list of `Statement`
            The statements of one code block.
        scalar : bool, optional
            ``True`` for the scalar statements executed before the inner loop,
            ``False`` for the vector statements executed inside it.

        Returns
        -------
        The translated code for this block; its exact form is defined by the
        subclass.
        """
        raise NotImplementedError

    def translate_statement_sequence(self, scalar_statements, vector_statements):
        """
        Translate a sequence of `Statement` into the target language, taking
        care to declare variables, etc. if necessary.

        Parameters
        ----------
        scalar_statements : dict
            Mapping from code block names to the statements executed before the
            inner loop.
        vector_statements : dict
            Mapping from code block names to the statements executed in the
            inner loop.

        Returns
        -------
        scalar_code : dict
            Mapping from block names to the translated scalar code, as returned
            by `translate_one_statement_sequence` with ``scalar=True``.
        vector_code : dict
            Mapping from block names to the translated vector code, as returned
            by `translate_one_statement_sequence` with ``scalar=False``.
        kwds : dict
            Values made available to the template (see `determine_keywords`).
        """
        scalar_code = {
            name: self.translate_one_statement_sequence(block, scalar=True)
            for name, block in scalar_statements.items()
        }
        vector_code = {
            name: self.translate_one_statement_sequence(block, scalar=False)
            for name, block in vector_statements.items()
        }

        kwds = self.determine_keywords()

        return scalar_code, vector_code, kwds

    def array_read_write(self, statements):
        """
        Helper function, gives the sets of `ArrayVariable` names that are read
        from and written to in the series of statements, as well as the index
        arrays used to access them.

        Parameters
        ----------
        statements : list of `Statement`
            The statements to analyse.

        Returns
        -------
        read : set of str
            Names of `ArrayVariable` objects read by the statements (in-place
            operations count as a read of their target). Variables that are
            also used as indices are not included.
        write : set of str
            Names of `ArrayVariable` objects written to by the statements.
        indices : set of str
            Names of the `ArrayVariable` index arrays (anything other than
            ``"_idx"`` and ``"0"``) used to access the read or written
            variables.

        Raises
        ------
        SyntaxError
            If a statement other than a ``:=`` definition writes to a scalar
            variable while ``allows_scalar_write`` is ``False``, or if a scalar
            variable is written with an expression referring to a vector
            variable.
        """
        variables = self.variables
        variable_indices = self.variable_indices
        read = set()
        write = set()
        for stmt in statements:
            ids = get_identifiers(stmt.expr)
            # if the operation is inplace this counts as a read.
            if stmt.inplace:
                ids.add(stmt.var)
            # Update in place instead of building a new set each iteration
            read |= ids
            if stmt.scalar or variable_indices[stmt.var] == "0":
                if stmt.op != ":=" and not self.allows_scalar_write:
                    raise SyntaxError(
                        f"Writing to scalar variable {stmt.var} not allowed in this"
                        " context."
                    )
                for name in ids:
                    if (
                        name in variables
                        and isinstance(variables[name], ArrayVariable)
                        and not (
                            variables[name].scalar or variable_indices[name] == "0"
                        )
                    ):
                        raise SyntaxError(
                            "Cannot write to scalar variable "
                            f"'{stmt.var}' with an expression "
                            f"referring to vector variable '{name}'"
                        )
            write.add(stmt.var)
        # Only keep identifiers that refer to array variables
        read = {
            varname
            for varname, var in variables.items()
            if isinstance(var, ArrayVariable) and varname in read
        }
        write = {
            varname
            for varname, var in variables.items()
            if isinstance(var, ArrayVariable) and varname in write
        }
        # Gather the indices stored as arrays (ignore _idx which is special)
        indices = {
            variable_indices[varname]
            for varname in read | write
            if variable_indices[varname] not in ("_idx", "0")
            and isinstance(variables[variable_indices[varname]], ArrayVariable)
        }
        # don't list arrays that are read explicitly and used as indices twice
        read -= indices
        return read, write, indices

    def get_conditional_write_vars(self):
        """
        Helper function, returns a dict mapping ``varname`` to
        ``condition_var_name``, indicating that when ``varname`` is written to,
        it should only be when ``condition_var_name`` is ``True``.

        Condition variables listed in ``self.override_conditional_write`` are
        ignored.
        """
        conditional_write_vars = {}
        for varname, var in self.variables.items():
            if getattr(var, "conditional_write", None) is not None:
                cname = var.conditional_write.name
                if cname not in self.override_conditional_write:
                    conditional_write_vars[varname] = cname
        return conditional_write_vars

    def arrays_helper(self, statements):
        """
        Combines the two helper functions `array_read_write` and
        `get_conditional_write_vars`, and updates the ``read`` set.

        Every conditionally written variable that is written to, and its
        condition variable, are added to ``read``.

        Returns
        -------
        read, write, indices, conditional_write_vars
            The three sets returned by `array_read_write` (with ``read``
            extended as described above) and the dict returned by
            `get_conditional_write_vars`.
        """
        read, write, indices = self.array_read_write(statements)
        conditional_write_vars = self.get_conditional_write_vars()
        read |= {var for var in write if var in conditional_write_vars}
        read |= {
            conditional_write_vars[var]
            for var in write
            if var in conditional_write_vars
        }
        return read, write, indices, conditional_write_vars

    def has_repeated_indices(self, statements):
        """
        Whether any of the statements potentially uses repeated indices (e.g.
        pre- or postsynaptic indices).

        Returns ``True`` if any written variable is accessed via an index array
        whose ``unique`` attribute is false. The indices ``"_idx"`` and ``"0"``
        are skipped (treated as unique).
        """
        variables = self.variables
        variable_indices = self.variable_indices
        # Only the written variables matter here
        _, write, _, _ = self.arrays_helper(statements)
        # Check whether we potentially deal with repeated indices (which will
        # be the case most importantly when we write to pre- or post-synaptic
        # variables in synaptic code)
        used_indices = {variable_indices[var] for var in write}
        all_unique = all(
            variables[index].unique
            for index in used_indices
            if index not in ("_idx", "0")
        )
        return not all_unique

    def translate(self, code, dtype):
        """
        Translates an abstract code block into the target language.

        Parameters
        ----------
        code : dict
            Mapping from block names to abstract code, as accepted by
            `make_statements`.
        dtype
            Passed on to `make_statements` (presumably the data type(s) of the
            variables — TODO confirm).

        Returns
        -------
        The result of `translate_statement_sequence`, i.e. a tuple
        ``(scalar_code, vector_code, kwds)``.

        Notes
        -----
        If a block of vector statements potentially uses repeated indices and
        `check_for_order_independence` raises an `OrderDependenceError`, a
        warning is logged instead of raising.
        """
        scalar_statements = {}
        vector_statements = {}
        for ac_name, ac_code in code.items():
            statements = make_statements(
                ac_code, self.variables, dtype, optimise=True, blockname=ac_name
            )
            scalar_statements[ac_name], vector_statements[ac_name] = statements
        for vs in vector_statements.values():
            # Check that the statements are meaningful independent on the order of
            # execution (e.g. for synapses)
            try:
                # only do order dependence if there are repeated indices
                if self.has_repeated_indices(vs):
                    check_for_order_independence(
                        vs, self.variables, self.variable_indices
                    )
            except OrderDependenceError:
                # If the abstract code is only one line, display it in full
                if len(vs) == 1:
                    error_msg = f"Abstract code: '{vs[0]}'\n"
                elif len(vs) > 1:
                    error_msg = (
                        f"{len(vs)} lines of abstract code, first line is: '{vs[0]}'\n"
                    )
                else:
                    # Nothing to show; avoid an IndexError masking the warning
                    error_msg = ""
                logger.warn(
                    "Came across an abstract code block that may not be "
                    "well-defined: the outcome may depend on the "
                    "order of execution. You can ignore this warning if "
                    "you are sure that the order of operations does not "
                    "matter. " + error_msg
                )

        translated = self.translate_statement_sequence(
            scalar_statements, vector_statements
        )

        return translated