File: spikegeneratorgroup.py

package info (click to toggle)
brian 2.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,872 kB
  • sloc: python: 51,820; cpp: 2,033; makefile: 108; sh: 72
file content (386 lines) | stat: -rw-r--r-- 15,012 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
"""
Module defining `SpikeGeneratorGroup`.
"""

import numpy as np

from brian2.core.functions import timestep
from brian2.core.spikesource import SpikeSource
from brian2.core.variables import Variables
from brian2.groups.group import CodeRunner, Group
from brian2.units.allunits import second
from brian2.units.fundamentalunits import Quantity, check_units
from brian2.utils.logger import get_logger

# Public API of this module: only the group class is exported.
__all__ = ["SpikeGeneratorGroup"]


# Module-level logger; used by `SpikeGeneratorGroup.before_run` to warn about
# spikes in the past and about periods too long to be represented as int32
# time-step counts.
logger = get_logger(__name__)


class SpikeGeneratorGroup(Group, CodeRunner, SpikeSource):
    """
    SpikeGeneratorGroup(N, indices, times, dt=None, clock=None,
                        period=0*second, when='thresholds', order=0,
                        sorted=False, name='spikegeneratorgroup*',
                        codeobj_class=None)

    A group emitting spikes at given times.

    Parameters
    ----------
    N : int
        The number of "neurons" in this group
    indices : array of integers
        The indices of the spiking cells
    times : `Quantity`
        The spike times for the cells given in ``indices``. Has to have the
        same length as ``indices``.
    period : `Quantity`, optional
        If this is specified, it will repeat spikes with this period. A
        period of 0s means not repeating spikes.
    dt : `Quantity`, optional
        The time step to be used for the simulation. Cannot be combined with
        the `clock` argument.
    clock : `Clock`, optional
        The update clock to be used. If neither a clock, nor the `dt` argument
        is specified, the `defaultclock` will be used.
    when : str, optional
        When to run within a time step, defaults to the ``'thresholds'`` slot.
        See :ref:`scheduling` for possible values.
    order : int, optional
        The priority of this group for operations occurring at the same time
        step and in the same scheduling slot. Defaults to 0.
    sorted : bool, optional
        Whether the given indices and times are already sorted. Set to ``True``
        if your events are already sorted (first by spike time, then by index),
        this can save significant time at construction if your arrays contain
        large numbers of spikes. Defaults to ``False``.
    name : str, optional
        The name of this group. Defaults to ``'spikegeneratorgroup*'``.
    codeobj_class : class, optional
        The code object class to use for running this group's code. Defaults
        to ``None``.

    Notes
    -----
    * If `sorted` is set to ``True``, the given arrays will not be copied
      (only affects runtime mode).
    * Each neuron may spike at most once per time step: `before_run` raises a
      ``ValueError`` if two spikes of the same neuron fall into the same time
      bin for the current ``dt``.
    """

    @check_units(N=1, indices=1, times=second, period=second)
    def __init__(
        self,
        N,
        indices,
        times,
        dt=None,
        clock=None,
        period=0 * second,
        when="thresholds",
        order=0,
        sorted=False,
        name="spikegeneratorgroup*",
        codeobj_class=None,
    ):
        Group.__init__(self, dt=dt, clock=clock, when=when, order=order, name=name)

        # We store the indices and times also directly in the Python object,
        # this way we can use them for checks in `before_run` even in standalone
        # TODO: Remove this when the checks in `before_run` have been moved to the template
        #: Array of spiking neuron indices.
        self._neuron_index = None
        #: Array of spiking neuron times.
        self._spike_time = None
        #: "Dirty flag" that will be set when spikes are changed after the
        #: `before_run` check
        self._spikes_changed = True

        # Let other objects know that we emit spikes events
        self.events = {"spike": None}

        self.codeobj_class = codeobj_class

        if N < 1 or int(N) != N:
            raise TypeError("N has to be an integer >=1.")
        N = int(N)
        self.start = 0
        self.stop = N

        self.variables = Variables(self)
        self.variables.create_clock_variables(self._clock)

        indices, times = self._check_args(
            indices, times, period, N, sorted, self._clock.dt
        )

        self.variables.add_constant("N", value=N)
        self.variables.add_array(
            "period",
            dimensions=second.dim,
            size=1,
            constant=True,
            read_only=True,
            scalar=True,
            dtype=self._clock.variables["t"].dtype,
        )
        self.variables.add_arange("i", N)
        self.variables.add_dynamic_array(
            "spike_number",
            values=np.arange(len(indices)),
            size=len(indices),
            dtype=np.int32,
            read_only=True,
            constant=True,
            index="spike_number",
            unique=True,
        )
        self.variables.add_dynamic_array(
            "neuron_index",
            values=indices,
            size=len(indices),
            dtype=np.int32,
            index="spike_number",
            read_only=True,
            constant=True,
        )
        self.variables.add_dynamic_array(
            "spike_time",
            values=times,
            size=len(times),
            dimensions=second.dim,
            index="spike_number",
            read_only=True,
            constant=True,
            dtype=self._clock.variables["t"].dtype,
        )
        self.variables.add_dynamic_array(
            "_timebins",
            size=len(times),
            index="spike_number",
            read_only=True,
            constant=True,
            dtype=np.int32,
        )
        self.variables.add_array(
            "_period_bins",
            size=1,
            constant=True,
            read_only=True,
            scalar=True,
            dtype=np.int32,
        )
        self.variables.add_array("_spikespace", size=N + 1, dtype=np.int32)
        self.variables.add_array(
            "_lastindex", size=1, values=0, dtype=np.int32, read_only=True, scalar=True
        )

        #: Remember the dt we used the last time when we checked the spike bins
        #: to not repeat the work for multiple runs with the same dt
        self._previous_dt = None

        CodeRunner.__init__(
            self,
            self,
            code="",
            template="spikegenerator",
            clock=self._clock,
            when=when,
            order=order,
            name=None,
        )

        # Activate name attribute access
        self._enable_group_attributes()

        self.variables["period"].set_value(period)

    def _full_state(self):
        """
        Return the parent class's full state, extended with the internal
        flags (``_previous_dt`` and ``_spikes_changed``) that decide whether
        `before_run` has to rebuild the time bins.
        """
        state = super()._full_state()
        # Store the internal information we use to decide whether to rebuild
        # the time bins
        state["_previous_dt"] = self._previous_dt
        state["_spikes_changed"] = self._spikes_changed
        return state

    def _restore_from_full_state(self, state):
        """
        Restore a state produced by `_full_state`: the two internal flags are
        set on this object, the remaining entries are handed to the parent
        class.
        """
        state = state.copy()  # copy to avoid errors for multiple restores
        self._previous_dt = state.pop("_previous_dt")
        self._spikes_changed = state.pop("_spikes_changed")
        super()._restore_from_full_state(state)

    @staticmethod
    def _has_duplicate_spikes(timebins, neuron_indices):
        """
        Return whether any neuron spikes more than once in the same time bin.

        Parameters
        ----------
        timebins : array of integers
            The time bin of each spike.
        neuron_indices : array of integers
            The neuron index of each spike (same length as ``timebins``).

        Returns
        -------
        bool
            ``True`` if at least one (time bin, neuron index) pair occurs more
            than once.
        """
        timebins = np.asarray(timebins)
        neuron_indices = np.asarray(neuron_indices)
        if len(timebins) < 2:
            return False
        # The stored spikes are sorted by their *exact* spike time (then by
        # index), not by time bin. Several spikes in the same bin are therefore
        # not necessarily ordered by neuron index (e.g. neuron 0, neuron 1,
        # neuron 0 again), so only comparing neighbours in the stored order can
        # miss duplicates. Sort by (bin, neuron index) before comparing.
        order = np.lexsort((neuron_indices, timebins))
        same_bin = np.diff(timebins[order]) == 0
        same_neuron = np.diff(neuron_indices[order]) == 0
        return bool(np.logical_and(same_bin, same_neuron).any())

    def before_run(self, run_namespace):
        """
        Check the spikes and the period against the current ``dt`` and, if the
        spikes or ``dt`` changed since the last check, recompute ``_lastindex``,
        ``_timebins`` and ``_period_bins`` before delegating to the parent
        class.

        Raises
        ------
        ValueError
            If the (non-zero, finite) period is smaller than ``dt``, or if a
            neuron spikes more than once within a single time step.
        NotImplementedError
            If the period is not an integer multiple of ``dt``.
        """
        # Do some checks on the period vs. dt
        dt = self.dt_[:]  # make a copy
        period = self.period_
        if period < np.inf and period != 0:
            if period < dt:
                raise ValueError(
                    f"The period of '{self.name}' is {self.period[:]!s}, "
                    f"which is smaller than its dt of {dt*second!s}."
                )

        if self._spikes_changed:
            current_t = self.variables["t"].get_value().item()
            timesteps = timestep(self._spike_time, dt)
            current_step = timestep(current_t, dt)
            in_the_past = np.nonzero(timesteps < current_step)[0]
            if len(in_the_past):
                logger.warn(
                    "The SpikeGeneratorGroup contains spike times "
                    "earlier than the start time of the current run "
                    f"(t = {current_t*second!s}), these spikes will be "
                    "ignored.",
                    name_suffix="ignored_spikes",
                )
                # Spike times are sorted, so the spikes in the past form a
                # prefix; start emitting right after it
                self.variables["_lastindex"].set_value(in_the_past[-1] + 1)
            else:
                self.variables["_lastindex"].set_value(0)

        # Check that we don't have more than one spike per neuron in a time bin
        if self._previous_dt is None or dt != self._previous_dt or self._spikes_changed:
            # We shift all the spikes by a tiny amount to make sure that spikes
            # at exact multiples of dt do not end up in the previous time bin
            # This shift has to be quite significant relative to machine
            # epsilon, we use 1e-3 of the dt here
            shift = 1e-3 * dt
            timebins = np.asarray(
                np.asarray(self._spike_time + shift) / dt, dtype=np.int32
            )
            if self._has_duplicate_spikes(timebins, self._neuron_index):
                raise ValueError(
                    f"Using a dt of {self.dt!s}, some neurons of "
                    f"SpikeGeneratorGroup '{self.name}' spike more than "
                    "once during a time step."
                )
            self.variables["_timebins"].set_value(timebins)
            period_bins = np.round(period / dt)
            max_int = np.iinfo(np.int32).max
            if period_bins > max_int:
                logger.warn(
                    f"Periods longer than {max_int} timesteps "
                    f"(={max_int*dt*second!s}) are not "
                    "supported, the period will therefore be "
                    "considered infinite. Set the period to 0*second "
                    "to avoid this "
                    "warning.",
                    "spikegenerator_long_period",
                )
                period = period_bins = 0
            if np.abs(period_bins * dt - period) > period * np.finfo(dt.dtype).eps:
                raise NotImplementedError(
                    f"The period of '{self.name}' is "
                    f"{self.period[:]!s}, which is "
                    "not an integer multiple of its dt "
                    f"of {dt*second!s}."
                )

            self.variables["_period_bins"].set_value(period_bins)

            self._previous_dt = dt
            self._spikes_changed = False

        super().before_run(run_namespace=run_namespace)

    @check_units(indices=1, times=second, period=second)
    def set_spikes(self, indices, times, period=0 * second, sorted=False):
        """
        set_spikes(indices, times, period=0*second, sorted=False)

        Change the spikes that this group will generate.

        This can be used to set the input for a second run of a model based on
        the output of a first run (if the input for the second run is already
        known before the first run, then all the information should simply be
        included in the initial `SpikeGeneratorGroup` initializer call,
        instead).

        Parameters
        ----------
        indices : array of integers
            The indices of the spiking cells
        times : `Quantity`
            The spike times for the cells given in ``indices``. Has to have the
            same length as ``indices``.
        period : `Quantity`, optional
            If this is specified, it will repeat spikes with this period. A
            period of 0s means not repeating spikes.
        sorted : bool, optional
            Whether the given indices and times are already sorted. Set to
            ``True`` if your events are already sorted (first by spike time,
            then by index), this can save significant time at construction if
            your arrays contain large numbers of spikes. Defaults to ``False``.
        """

        indices, times = self._check_args(
            indices, times, period, self.N, sorted, self.dt
        )

        self.variables["period"].set_value(period)
        self.variables["neuron_index"].resize(len(indices))
        self.variables["spike_time"].resize(len(indices))
        self.variables["spike_number"].resize(len(indices))
        self.variables["spike_number"].set_value(np.arange(len(indices)))
        self.variables["_timebins"].resize(len(indices))
        self.variables["neuron_index"].set_value(indices)
        self.variables["spike_time"].set_value(times)
        # _lastindex and _timebins will be set as part of before_run

    def _check_args(self, indices, times, period, N, sorted, dt):
        """
        Validate spike indices/times and the period, and sort the spikes.

        The sorted arrays are also stored on the object (``_neuron_index`` and
        ``_spike_time``) and the ``_spikes_changed`` flag is set, so that
        `before_run` recomputes its derived information.

        Returns
        -------
        indices, times : `numpy.ndarray`
            The indices and times as plain NumPy arrays (``times`` stripped of
            its units), sorted first by time and then by index unless
            ``sorted`` is ``True``.

        Raises
        ------
        ValueError
            If the lengths differ, the period is negative or not larger than
            the last spike time, a spike time is negative, or an index lies
            outside ``[0, N[``.
        """
        times = Quantity(times)
        if len(indices) != len(times):
            raise ValueError(
                "Length of the indices and times array must "
                f"match, but {len(indices)} != {len(times)}"
            )
        if period < 0 * second:
            raise ValueError("The period cannot be negative.")
        elif len(times) and period != 0 * second:
            period_bins = np.round(period / dt)
            # Note: we have to use the timestep function here, to use the same
            # binning as in the actual simulation
            max_bin = timestep(np.max(times), dt)
            if max_bin >= period_bins:
                raise ValueError(
                    "The period has to be greater than the maximum of the spike times"
                )
        if len(times) and np.min(times) < 0 * second:
            raise ValueError("Spike times cannot be negative")
        if len(indices) and (np.min(indices) < 0 or np.max(indices) >= N):
            raise ValueError(f"Indices have to lie in the interval [0, {int(N)}[")

        times = np.asarray(times)
        indices = np.asarray(indices)
        if not sorted:
            # sort times and indices first by time, then by indices
            sort_indices = np.lexsort((indices, times))
            indices = indices[sort_indices]
            times = times[sort_indices]

        # We store the indices and times also directly in the Python object,
        # this way we can use them for checks in `before_run` even in standalone
        # TODO: Remove this when the checks in `before_run` have been moved to the template
        self._neuron_index = indices
        self._spike_time = times
        self._spikes_changed = True

        return indices, times

    @property
    def spikes(self):
        """
        The spikes returned by the most recent thresholding operation.
        """
        # Note that we have to directly access the ArrayVariable object here
        # instead of using the Group mechanism by accessing self._spikespace
        # Using the latter would cut _spikespace to the length of the group
        spikespace = self.variables["_spikespace"].get_value()
        return spikespace[: spikespace[-1]]

    def __repr__(self):
        cls = self.__class__.__name__
        size = self.variables["neuron_index"].size
        return (
            f"{cls}({self.N}, indices=<length {size} array>, times=<length"
            f" {size} array>)"
        )