1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161
|
"""
GSLCodeGenerators for code that uses the ODE solver provided by the GNU Scientific Library (GSL)
"""
import os
import re
import numpy as np
from brian2.codegen.generators import c_data_type
from brian2.codegen.permutation_analysis import (
OrderDependenceError,
check_for_order_independence,
)
from brian2.codegen.translation import make_statements
from brian2.core.functions import Function
from brian2.core.preferences import BrianPreference, PreferenceError, prefs
from brian2.core.variables import ArrayVariable, AuxiliaryVariable, Constant
from brian2.parsing.statements import parse_statement
from brian2.units.fundamentalunits import fail_for_dimension_mismatch
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers, word_substitute
# Public names of this module: the generic GSL generator base class plus the
# target-specific generators (the latter are presumably defined further down
# in this file; not visible in this section).
__all__ = ["GSLCodeGenerator", "GSLCPPCodeGenerator", "GSLCythonCodeGenerator"]
# Module-level logger (used e.g. for the order-dependence warning emitted by
# GSLCodeGenerator.translate).
logger = get_logger(__name__)
def valid_gsl_dir(val):
    """
    Validator for the ``GSL.directory`` preference.

    ``None`` is accepted (the GSL headers are then expected to be found on
    the include path already, e.g. from a conda installation). Any other
    value must be the path of an existing directory whose ``gsl``
    subdirectory contains ``gsl_odeiv2.h``, ``gsl_errno.h`` and
    ``gsl_matrix.h``.

    Returns
    -------
    valid : bool
        Always ``True``; invalid values raise instead of returning ``False``.

    Raises
    ------
    PreferenceError
        If ``val`` is not a string, is not an existing directory, or lacks
        one of the required header files.
    """
    if val is None:
        return True
    if not isinstance(val, str):
        raise PreferenceError(
            f"Illegal value for GSL directory: {str(val)}, has to be str"
        )
    if not os.path.isdir(val):
        raise PreferenceError(
            f"Illegal value for GSL directory: {val}, has to be existing directory"
        )
    header_dir = os.path.join(val, "gsl")
    required_headers = ("gsl_odeiv2.h", "gsl_errno.h", "gsl_matrix.h")
    if all(
        os.path.isfile(os.path.join(header_dir, header))
        for header in required_headers
    ):
        return True
    raise PreferenceError(
        f"Illegal value for GSL directory: '{val}', "
        "has to contain gsl_odeiv2.h, gsl_errno.h "
        "and gsl_matrix.h"
    )
# Register the "GSL" preference category. Its only preference, ``directory``,
# points at the GSL header files; values are checked by valid_gsl_dir and the
# default is None.
prefs.register_preferences(
    "GSL",
    "Directory containing GSL code",
    directory=BrianPreference(
        validator=valid_gsl_dir,
        docs=(
            "Set path to directory containing GSL header files (gsl_odeiv2.h etc.)"
            "\nIf this directory is already in Python's include (e.g. because of "
            "conda installation), this path can be set to None."
        ),
        default=None,
    ),
)
class GSLCodeGenerator:
"""
GSL code generator.
Notes
-----
Approach is to first let the already existing code generator for a target
language do the bulk of the translating from abstract_code to actual code.
This generated code is slightly adapted to render it GSL compatible.
The most critical part here is that the vector_code that is normally
contained in a loop in the ```main()``` is moved to the function ```_GSL_func```
that is sent to the GSL integrator. The variables used in the vector_code are
added to a struct named ```dataholder``` and their values are set from the
Brian namespace just before the scalar code block.
"""
def __init__(
self,
variables,
variable_indices,
owner,
iterate_all,
codeobj_class,
name,
template_name,
override_conditional_write=None,
allows_scalar_write=False,
):
self.generator = codeobj_class.original_generator_class(
variables,
variable_indices,
owner,
iterate_all,
codeobj_class,
name,
template_name,
override_conditional_write,
allows_scalar_write,
)
self.method_options = dict(owner.state_updater.method_options)
self.integrator = owner.state_updater.integrator
# default timestep to start with is the timestep of the NeuronGroup itself
self.method_options["dt_start"] = owner.dt.variable.get_value()[0]
self.variable_flags = owner.state_updater._gsl_variable_flags
def __getattr__(self, item):
return getattr(self.generator, item)
# A series of functions that should be overridden by child class:
def c_data_type(self, dtype):
    """
    Return the C type name for the ``dtype`` attached to a Brian variable.

    The C++ generator already provides this function, but the Cython
    generator does not. GSL code generation needs it for both, so
    subclasses must override this method.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    # Raise (not return) so that a missing override fails loudly instead of
    # silently producing the exception class as a "type name".
    raise NotImplementedError
def initialize_array(self, varname, values):
    """
    Return code that defines a static array filled with ``values``.

    Must be overridden by target-language subclasses. For C++, calling it
    with ``"array"`` and ``[1.0, 3.0, 2.0]`` is expected to yield
    ``double array[] = {1.0, 3.0, 2.0}``.

    Parameters
    ----------
    varname : str
        Name of the array variable to initialize.
    values : list of float
        Values to store in the array.

    Returns
    -------
    code : str
        One or more lines of array initialization code.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError()
def var_init_lhs(self, var, type):
    """
    Return the left-hand side of an initializing assignment.

    Must be overridden by target-language subclasses: for C++ the result
    is ``type + var``, for Cython just ``var``.

    Parameters
    ----------
    var : str
        Name of the variable being initialized.
    type : str
        Target-language type prefix (e.g. ``"const double "``).

    Returns
    -------
    code : str
        The left-hand side of the initialization.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError()
def unpack_namespace_single(self, var_obj, in_vector, in_scalar):
    """
    Return code that pulls one variable out of the Brian namespace.

    The generated code differs substantially between C++ and Cython, not
    just syntactically, so there is no shared implementation here;
    subclasses must override this method.

    Parameters
    ----------
    var_obj : `Variable`
        The variable to unpack.
    in_vector : bool
        Whether the variable is used in the vector code.
    in_scalar : bool
        Whether the variable is used in the scalar code.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError()
# GSL functions that are the same for all target languages:
def find_function_names(self):
    """
    List the names in ``self.variables`` that refer to `Function` objects.

    Functions are already handled by the regular Brian generator and must
    be skipped during the GSL translation. That generator also removes them
    from the variables dict during its own translation. Therefore this method
    has to run before that step, and its result is stored for later use.

    Returns
    -------
    function_names : list of str
        Names of the functions used in the code.
    """
    function_names = []
    for identifier, obj in self.variables.items():
        if isinstance(obj, Function):
            function_names.append(identifier)
    return function_names
def is_cpp_standalone(self):
    """
    Report whether the active Brian device is the C++ standalone device.

    Returns
    -------
    is_cpp_standalone : bool
        ``True`` if ``get_device()`` returns a `CPPStandaloneDevice`.

    See Also
    --------
    is_constant_and_cpp_standalone : caches and uses this result
    """
    # Deferred imports: importing these at module level causes an import cycle.
    from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
    from brian2.devices.device import get_device

    return isinstance(get_device(), CPPStandaloneDevice)
def is_constant_and_cpp_standalone(self, var_obj):
    """
    Return whether ``var_obj`` is a `Constant` and the device is cpp_standalone.

    With cpp_standalone, constants are "frozen", i.e. replaced by their
    literal value. Applying the GSL renaming ``var -> _GSL_dataholder.var`` to
    them would therefore produce nonsense such as ``_GSL_dataholder.1.2 = 1.2``.
    Callers use this check to leave such variables alone.

    The device check is performed once and cached in ``self.cpp_standalone``.
    NOTE(review): since unknown attributes are delegated to
    ``self.generator`` via ``__getattr__``, the ``hasattr`` check would also
    find a ``cpp_standalone`` attribute on the wrapped generator.

    Parameters
    ----------
    var_obj : `Variable`
        The variable to check.

    Returns
    -------
    is_cpp_standalone : bool
        Whether the variable is a `Constant` and the device is cpp_standalone.
    """
    if not hasattr(self, "cpp_standalone"):
        self.cpp_standalone = self.is_cpp_standalone()
    if not isinstance(var_obj, Constant):
        return False
    return self.cpp_standalone
def find_differential_variables(self, code):
    """
    Recover the differential variables tagged by `GSLStateUpdater`.

    The state updater assigns each differential equation to a variable
    named ``_gsl_{var}_f{ind}``. This method scans the statements for such
    left-hand sides and returns the ``var -> ind`` pairs.

    Parameters
    ----------
    code : list of str
        Code blocks, each possibly containing several newline-separated
        statements.

    Returns
    -------
    diff_vars : dict
        Variable names mapped to their differential equation index. The
        index is kept as the string captured from the tag.
    """
    tag_pattern = re.compile("_gsl_(.+?)_f([0-9]*)$")
    diff_vars = {}
    for expr_set in code:
        for expr in expr_set.split("\n"):
            expr = expr.strip(" ")
            try:
                lhs, op, rhs, comment = parse_statement(expr)
            except ValueError:
                # Not a statement (e.g. an empty line). Skip it rather than
                # reusing the ``lhs`` of a previous line, which may not even
                # be defined yet.
                continue
            m = tag_pattern.search(lhs)
            if m:
                diff_vars[m.group(1)] = m.group(2)
    return diff_vars
def diff_var_to_replace(self, diff_vars):
    """
    Collect the differential-variable substitutions needed for GSL code.

    The code produced by Brian's regular generators (C++ or Cython) has
    to be patched to become GSL compatible. This method gathers the
    patches that concern differential equation variables:

    * whatever ``self.var_replace_diff_var_lhs`` returns for each variable;
    * reading the variable from its Brian array at ``_idx`` is replaced by
      reading GSL's state vector ``_GSL_y`` at the variable's index.

    Parameters
    ----------
    diff_vars : dict
        Variable names mapped to their differential equation index.

    Returns
    -------
    to_replace : dict
        Text to replace (keys) mapped to its replacement (values).
    """
    known_variables = self.variables
    replacements = {}
    for var, index in tuple(diff_vars.items()):
        replacements.update(self.var_replace_diff_var_lhs(var, index))
        array_name = self.generator.get_array_name(
            known_variables[var], access_data=True
        )
        # Only the index name "_idx" is handled here.
        replacements[f"{var} = {array_name}[_idx]"] = f"{var} = _GSL_y[{index}]"
    return replacements
def get_dimension_code(self, diff_num):
    """
    Generate the ``set_dimension`` function that reports the system size.

    GSL has to know how many differential variables the ODE system has.
    The vector loop code is meant to stay identical between simulations, so
    this number is provided by a separate function instead.

    Parameters
    ----------
    diff_num : int
        Number of differential variables in the ODE system.

    Returns
    -------
    set_dimension_code : str
        Source of the function, with the target-language syntax from
        ``self.syntax`` filled in.
    """
    template = (
        "\n{start_declare}int set_dimension(size_t * dimension){open_function}"
        "\n\tdimension[0] = "
        + ("%d" % diff_num)
        + "{end_statement}"
        "\n\treturn GSL_SUCCESS{end_statement}{end_function}"
    )
    return template.format(**self.syntax)
def yvector_code(self, diff_vars):
    """
    Generate ``_fill_y_vector`` and ``_empty_y_vector``.

    ``_fill_y_vector`` copies the differential variables of element
    ``_idx`` from their Brian arrays (reached through ``_GSL_dataholder``)
    into GSL's state vector ``_GSL_y``. ``_empty_y_vector`` copies them back
    after integration.

    Parameters
    ----------
    diff_vars : dict
        Variable names mapped to their differential index (any value
        accepted by ``int()``).

    Returns
    -------
    yvector_code : str
        Source of both functions in one string, with the target-language
        syntax from ``self.syntax`` filled in.
    """
    load_lines = [
        "\n{start_declare}int _fill_y_vector(_dataholder *"
        "_GSL_dataholder, double * _GSL_y, int _idx){open_function}"
    ]
    store_lines = [
        "\n{start_declare}int _empty_y_vector(_dataholder * "
        "_GSL_dataholder, double * _GSL_y, int _idx){open_function}"
    ]
    for var, index in tuple(diff_vars.items()):
        slot = int(index)
        array_name = self.generator.get_array_name(
            self.variables[var], access_data=True
        )
        # The {placeholders} are filled in by the single .format() below.
        load_lines.append(
            "\t_GSL_y[%d] = _GSL_dataholder{access_pointer}%s[_idx]{end_statement}"
            % (slot, array_name)
        )
        store_lines.append(
            "\t_GSL_dataholder{access_pointer}%s[_idx] = _GSL_y[%d]{end_statement}"
            % (array_name, slot)
        )
    epilogue = "\treturn GSL_SUCCESS{end_statement}{end_function}"
    all_lines = load_lines + [epilogue] + store_lines + [epilogue]
    return "\n".join(all_lines).format(**self.syntax)
def make_function_code(self, lines):
    """
    Wrap GSL-translated vector code into the ``_GSL_func`` function.

    Adds the fixed parts of ``_GSL_func`` around code written elsewhere
    (`translate_vector_code`): the signature, the cast of ``params`` to the
    ``_dataholder`` pointer, the ``_idx`` lookup, and the final
    ``return GSL_SUCCESS``. Only these fixed parts are formatted with
    ``self.syntax``. ``lines`` is already target-language code and is
    inserted verbatim, so literal braces in it (e.g. a C++ block) are
    neither interpreted as format fields nor collapsed.

    Parameters
    ----------
    lines : str
        Code containing the GSL version of the equations.

    Returns
    -------
    function_code : str
        Code describing ``_GSL_func``, which is passed to the GSL integrator.
    """
    header = (
        "\n{start_declare}int _GSL_func(double t, const double "
        "_GSL_y[], double f[], void * params){open_function}"
        "\n\t{start_declare}_dataholder * _GSL_dataholder = {open_cast}"
        "_dataholder *{close_cast} params{end_statement}"
        "\n\t{start_declare}int _idx = _GSL_dataholder{access_pointer}_idx"
        "{end_statement}"
    ).format(**self.syntax)
    footer = "\treturn GSL_SUCCESS{end_statement}{end_function}".format(
        **self.syntax
    )
    return "\n".join([header, lines, footer])
def write_dataholder_single(self, var_obj):
    """
    Return the ``_dataholder`` struct member declaration for one variable.

    Array variables become a pointer member named after their array
    (e.g. ``double* _array_neurongroup_v``). The wrapped generator's
    ``restrict`` qualifier is used if it has one, but not for scalar or
    single-element arrays. All other variables become a plain member named
    after ``var_obj.name``.

    Parameters
    ----------
    var_obj : `Variable`
        The variable to declare.

    Returns
    -------
    code : str
        The declaration, ending in the ``{end_statement}`` placeholder,
        which the caller fills in via ``str.format``.
    """
    dtype = self.c_data_type(var_obj.dtype)
    if not isinstance(var_obj, ArrayVariable):
        return f"{dtype} {var_obj.name}{{end_statement}}"
    pointer_name = self.get_array_name(var_obj, access_data=True)
    if var_obj.scalar or var_obj.size == 1:
        restrict = ""
    else:
        # Generators without a restrict keyword fall back to "".
        restrict = getattr(self.generator, "restrict", "")
    return f"{dtype}* {restrict} {pointer_name}{{end_statement}}"
def write_dataholder(self, variables_in_vector):
    """
    Return the full definition of the ``_dataholder`` struct.

    The struct always holds ``_idx`` plus one member per variable used in
    the vector code. The following are left out:

    * ``t``, which GSL passes to ``_GSL_func`` itself;
    * the ``_gsl`` tag variables;
    * constants frozen by the cpp_standalone device.

    Parameters
    ----------
    variables_in_vector : dict
        Variable names mapped to their `Variable` objects.

    Returns
    -------
    code : str
        Code for the ``_dataholder`` struct, with ``self.syntax`` filled in.
    """
    members = []
    for var, var_obj in tuple(variables_in_vector.items()):
        skip = (
            var == "t"
            or "_gsl" in var
            or self.is_constant_and_cpp_standalone(var_obj)
        )
        if not skip:
            members.append(f" {self.write_dataholder_single(var_obj)}")
    lines = [
        "\n{start_declare}struct _dataholder{open_struct}",
        "\tint _idx{end_statement}",
        *members,
        "{end_struct}",
    ]
    return "\n".join(lines).format(**self.syntax)
def scale_array_code(self, diff_vars, method_options):
    """
    Return the code defining ``_GSL_scale_array``.

    The array holds the absolute error for each differential variable. The
    value comes from ``absolute_error_per_variable`` if given there,
    otherwise from ``absolute_error``. Entries are placed at each
    variable's assigned differential index, i.e. in the same order as the
    GSL state vector ``_GSL_y``, not in alphabetical order of names.

    Parameters
    ----------
    diff_vars : dict
        Variable names mapped to their differential index (any value
        accepted by ``int()``).
    method_options : dict
        Integrator settings; ``absolute_error`` and
        ``absolute_error_per_variable`` are read.

    Returns
    -------
    code : str
        Array initialization code produced by ``self.initialize_array``.

    Raises
    ------
    TypeError
        If ``absolute_error`` is not a float, or
        ``absolute_error_per_variable`` is neither ``None`` nor a dict.
    KeyError
        If an error is given for an unknown or non-integrated variable.
    """
    abs_per_var = method_options["absolute_error_per_variable"]
    abs_default = method_options["absolute_error"]
    if not isinstance(abs_default, float):
        raise TypeError(
            "The absolute_error key in method_options should be "
            f"a float. Was type {type(abs_default)}"
        )
    if abs_per_var is None:
        diff_scale = {var: float(abs_default) for var in diff_vars}
    elif isinstance(abs_per_var, dict):
        diff_scale = {}
        for var, error in abs_per_var.items():
            # first do some checks on input
            if var not in diff_vars:
                if var not in self.variables:
                    raise KeyError(
                        "absolute_error specified for variable that "
                        f"does not exist: {var}"
                    )
                raise KeyError(
                    "absolute_error specified for variable that is "
                    f"not being integrated: {var}"
                )
            fail_for_dimension_mismatch(
                error,
                self.variables[var],
                "Unit of absolute_error_per_variable "
                f"for variable {var} does not match "
                "unit of variable itself",
            )
            # if all these are passed we can add the value for error in base units
            diff_scale[var] = float(error)
        # variables without an explicit error get the default value
        for var in diff_vars:
            if var not in abs_per_var:
                diff_scale[var] = float(abs_default)
    else:
        raise TypeError(
            "The absolute_error_per_variable key in method_options "
            "should either be None or a dictionary "
            "containing the error for each individual state variable. "
            f"Was type {type(abs_per_var)}"
        )
    # Position i must hold the error of the variable stored in _GSL_y[i],
    # so order by differential index rather than by variable name.
    ordered_vars = sorted(diff_vars, key=lambda var: int(diff_vars[var]))
    return self.initialize_array(
        "_GSL_scale_array", [diff_scale[var] for var in ordered_vars]
    )
def find_undefined_variables(self, statements):
    """
    Create auxiliary variables for statement targets missing from ``self.variables``.

    Brian does not record the ``_lio_`` variables it introduces. The GSL
    implementation stores them in the ``_dataholder`` struct, which requires
    their datatype. Every left-hand side not found in ``self.variables`` is
    therefore returned as an `AuxiliaryVariable` carrying the statement's
    dtype.

    Parameters
    ----------
    statements : list
        Statement objects; their ``var`` and ``dtype`` attributes are read.

    Returns
    -------
    other_variables : dict
        Names mapped to `AuxiliaryVariable` objects. This dict is kept
        separate from ``self.variables`` so that variables from the Brian
        namespace can be told apart from ones defined in the code itself.
    """
    known = self.variables
    return {
        stmt.var: AuxiliaryVariable(stmt.var, dtype=stmt.dtype)
        for stmt in statements
        if stmt.var not in known
    }
def find_used_variables(self, statements, other_variables):
    """
    Collect the variables used by the given statements.

    Every identifier on a right-hand side is included, except names in
    ``self.function_names``. The objects are looked up first in
    ``self.variables`` and then in ``other_variables``. Arrays that are
    only read or written (as reported by ``array_read_write``) and do not
    appear on a right-hand side, e.g. ``not_refractory``, are added from
    ``self.variables``.

    Parameters
    ----------
    statements : list
        Statement objects (their ``expr`` attribute is read).
    other_variables : dict
        Variables defined in the code itself (see
        `find_undefined_variables`).

    Returns
    -------
    used_variables : dict
        Variable names mapped to their `Variable` objects, which carry the
        needed info (dtype, name, whether it is an array).
    """
    known = self.variables
    used = {}
    for statement in statements:
        for identifier in get_identifiers(statement.expr):
            if identifier in self.function_names:
                continue
            try:
                obj = known[identifier]
            except KeyError:
                obj = other_variables[identifier]
            used[identifier] = obj
    # The right-hand-side scan misses arrays that are only written to.
    read, write, _ = self.array_read_write(statements)
    for name in read | write:
        if name not in used:
            # Such names are always arrays, hence present in self.variables.
            used[name] = known[name]
    return used
def to_replace_vector_vars(self, variables_in_vector, ignore=frozenset()):
    """
    Build substitutions that route vector-code variables through ``_GSL_dataholder``.

    Array variables are renamed by array name and other variables by their
    name, e.g. ``_lio_1`` becomes ``_GSL_dataholder._lio_1`` (Cython) or
    ``_GSL_dataholder->_lio_1`` (C++). The following are skipped:

    * ``_gsl`` tags and names in ``ignore``;
    * constants frozen by cpp_standalone. These are also removed from
      ``self.variables_to_be_processed``.

    If ``t`` occurs, its declaration from the clock array is mapped to the
    empty string, because GSL provides its own ``t``; for C++ that is e.g.
    ``{'const double t = _ptr_array_defaultclock_t[0];': ''}``. ``t`` is then
    removed from ``self.variables_to_be_processed``.

    Parameters
    ----------
    variables_in_vector : dict
        Variable names mapped to the `Variable` objects used in vector code.
    ignore : set, optional
        Variable names to leave untouched.

    Returns
    -------
    to_replace : dict
        Text to replace (keys) mapped to its replacement (values).
    """
    arrow = self.syntax["access_pointer"]
    replacements = {}
    time_var = None
    for var, var_obj in tuple(variables_in_vector.items()):
        if var_obj.name == "t":
            time_var = var_obj
            continue
        if "_gsl" in var or var in ignore:
            continue
        if self.is_constant_and_cpp_standalone(var_obj):
            # Frozen into the code by cpp_standalone; nothing to store.
            self.variables_to_be_processed.remove(var_obj.name)
            continue
        if isinstance(var_obj, ArrayVariable):
            key = self.get_array_name(var_obj, access_data=True)
        else:
            key = var
        replacements[key] = f"_GSL_dataholder{arrow}{key}"
    if time_var is not None:
        declaration = "{lhs} = {array}[0]{end}".format(
            lhs=self.var_init_lhs("t", "const double "),
            array=self.get_array_name(time_var, access_data=True),
            end=self.syntax["end_statement"],
        )
        replacements[declaration] = ""
        self.variables_to_be_processed.remove("t")
    return replacements
def unpack_namespace(
    self, variables_in_vector, variables_in_scalar, ignore=frozenset()
):
    """
    Write code that unpacks the Brian namespace into the generated code.

    For each variable in ``self.variables`` (except names in ``ignore`` and
    constants frozen by cpp_standalone), the code comes from
    ``unpack_namespace_single``, which is told whether the variable is
    used in vector and/or scalar code. For vector code, this means storing
    the variable in the dataholder (``_GSL_dataholder->var`` or
    ``_GSL_dataholder.var``). A variable may occur in both kinds of code.
    Variables used in vector code are removed from
    ``self.variables_to_be_processed``.

    Parameters
    ----------
    variables_in_vector : dict
        Variable names mapped to `Variable` objects used in vector code.
    variables_in_scalar : dict
        Variable names mapped to `Variable` objects used in scalar code.
    ignore : set, optional
        Variable names to skip.

    Returns
    -------
    unpack_namespace_code : str
        The generated code fragments joined by newlines.
    """
    fragments = []
    for var, var_obj in tuple(self.variables.items()):
        if var in ignore or self.is_constant_and_cpp_standalone(var_obj):
            continue
        needed_in_vector = var in variables_in_vector
        needed_in_scalar = var in variables_in_scalar
        if needed_in_vector:
            self.variables_to_be_processed.remove(var)
        fragments.append(
            self.unpack_namespace_single(var_obj, needed_in_vector, needed_in_scalar)
        )
    return "\n".join(fragments)
def translate_vector_code(self, code_lines, to_replace):
    """
    Turn Brian's vector code into the body of ``_GSL_func``.

    The method works in four steps:

    1. Each line is prefixed with one space.
    2. Identifiers are substituted according to ``to_replace`` with
       ``word_substitute``.
    3. Keys ending in an index expression such as ``[_idx]`` (optionally
       followed by ``;``) are additionally replaced as plain text, because
       regex word boundaries do not work after the closing bracket.
    4. The result is checked for leftover ``_gsl`` tags.

    Parameters
    ----------
    code_lines : list of str
        Vector code blocks, each possibly spanning several lines.
    to_replace : dict
        Strings to replace (keys) mapped to their replacements (values), see
        `to_replace_vector_vars` and `diff_var_to_replace`.

    Returns
    -------
    vector_code : str
        Code to be placed inside the function passed to the GSL integrator.

    Raises
    ------
    AssertionError
        If a ``_gsl`` tag is still present after substitution.
    """
    # every line separately so that indentation is correct
    indented = [
        f" {line}" for expr_set in code_lines for line in expr_set.split("\n")
    ]
    code = word_substitute("\n".join(indented), to_replace)
    # Plain-text replacement: from_sub may contain characters that are
    # special in regular expressions (e.g. "*", ".", "("), and to_sub may
    # contain backslashes. Neither should be interpreted.
    for from_sub, to_sub in to_replace.items():
        if re.search(r"\[(\w+)\];?$", from_sub):
            code = code.replace(from_sub, to_sub)
    if "_gsl" in code:
        raise AssertionError(
            "Translation failed, _gsl still in code (should only "
            "be tag, and should be replaced).\n"
            f"Code:\n{code}"
        )
    return code
def translate_scalar_code(
    self, code_lines, variables_in_scalar, variables_in_vector
):
    """
    Translate scalar code for use before the GSL loop.

    Each line is handled as follows:

    * Lines without a parsable ``name = ...`` assignment are kept as is.
    * Assignments to variables used in scalar code are kept as is.
    * Assignments to variables used only in vector code are rewritten to
      store into ``_GSL_dataholder.<var>`` instead. Those variables are
      removed from ``self.variables_to_be_processed``.
    * Assignments to ``t`` are dropped.
    * Assignments to variables used in neither dict are dropped as well.

    Parameters
    ----------
    code_lines : list of str
        Scalar code lines.
    variables_in_scalar : dict
        Variable names mapped to `Variable` objects used in scalar code.
    variables_in_vector : dict
        Variable names mapped to `Variable` objects used in vector code.

    Returns
    -------
    scalar_code : str
        Code fragment to be injected in the main code before the loop.

    Raises
    ------
    AssertionError
        If a vector variable has already been processed.
    """
    code = []
    for line in code_lines:
        m = re.search(r"(\w+ = .*)", line)
        try:
            new_line = m.group(1)
            var, op, expr, comment = parse_statement(new_line)
        except (ValueError, AttributeError):
            # no assignment on this line (m is None) or not a statement
            code.append(line)
            continue
        if var in variables_in_scalar:
            code.append(line)
        elif var in variables_in_vector:
            if var == "t":
                continue
            try:
                self.variables_to_be_processed.remove(var)
            except (KeyError, ValueError):
                # list.remove raises ValueError (set.remove would raise KeyError)
                raise AssertionError(
                    "Trying to process variable named %s by "
                    "putting its value in the _GSL_dataholder "
                    "based on scalar code, but the variable "
                    "has been processed already." % var
                ) from None
            code.append(f"_GSL_dataholder.{var} {op} {expr} {comment}")
    return "\n".join(code)
def add_gsl_variables_as_non_scalar(self, diff_vars):
    """
    Register the ``_gsl_{var}_f{ind}`` tag variables as non-scalar.

    `GSLStateUpdater` substitutes the differential equation variables with
    these tags, which carry the information needed for the GSL translation.
    Without this registration, Brian would render a tag as a scalar
    statement whenever no vector variable appears on its right-hand side.

    Parameters
    ----------
    diff_vars : dict
        Variable names mapped to their differential equation index.
    """
    for var_name, index in tuple(diff_vars.items()):
        tag = f"_gsl_{var_name}_f{index}"
        # Note: the auxiliary variable is named after the plain variable.
        self.variables[tag] = AuxiliaryVariable(var_name, scalar=False)
def add_meta_variables(self, options):
    """
    Register the optional per-element bookkeeping arrays of the GSL integrator.

    Depending on the flags in ``options``, up to three arrays of length
    ``N`` are added to the owner's variables and to ``self.variables``:

    * ``_last_timestep`` (float64, filled with ``options["dt_start"]``) if
      ``options["use_last_timestep"]``;
    * ``_failed_steps`` (int32) if ``options["save_failed_steps"]``;
    * ``_step_count`` (int32) if ``options["save_step_count"]``.

    This has to happen before the regular generator translates the code, so
    that Brian's namespace unpacking handles them. A ``KeyError`` from
    ``add_array`` is taken to mean the array already exists from an earlier
    run.

    Parameters
    ----------
    options : dict
        Integrator settings (the method options of this generator).

    Returns
    -------
    pointers : dict
        Maps ``"pointer_last_timestep"``, ``"pointer_failed_steps"`` and
        ``"pointer_step_count"`` to an ``<array name>[_idx]`` expression,
        or to ``None`` when the corresponding option is disabled.
    """

    def _pointer_to_array(name, dtype, fill_value=None):
        # Ensure the owner has array ``name`` of length N, expose it in
        # self.variables and return the indexed access expression.
        size = int(self.variables["N"].get_value())
        extra = {}
        if fill_value is not None:
            extra["values"] = np.ones(size) * fill_value
        try:
            self.owner.variables.add_array(name, size=size, dtype=dtype, **extra)
        except KeyError:
            # has already been run
            pass
        self.variables[name] = self.owner.variables.get(name)
        return f"{self.get_array_name(self.variables[name])}[_idx]"

    pointer_last_timestep = None
    if options["use_last_timestep"]:
        pointer_last_timestep = _pointer_to_array(
            "_last_timestep", np.float64, fill_value=options["dt_start"]
        )
    pointer_failed_steps = None
    if options["save_failed_steps"]:
        pointer_failed_steps = _pointer_to_array("_failed_steps", np.int32)
    pointer_step_count = None
    if options["save_step_count"]:
        pointer_step_count = _pointer_to_array("_step_count", np.int32)
    return {
        "pointer_last_timestep": pointer_last_timestep,
        "pointer_failed_steps": pointer_failed_steps,
        "pointer_step_count": pointer_step_count,
    }
def translate(
self, code, dtype
): # TODO: it's not so nice we have to copy the contents of this function..
"""
Translates an abstract code block into the target language.
"""
# first check if user code is not using variables that are also used by GSL
reserved_variables = [
"_dataholder",
"_fill_y_vector",
"_empty_y_vector",
"_GSL_dataholder",
"_GSL_y",
"_GSL_func",
]
if any([var in self.variables for var in reserved_variables]):
# import here to avoid circular import
raise ValueError(
f"The variables {str(reserved_variables)} are reserved for the GSL"
" internal code."
)
# if the following statements are not added, Brian translates the
# differential expressions in the abstract code for GSL to scalar statements
# in the case no non-scalar variables are used in the expression
diff_vars = self.find_differential_variables(list(code.values()))
self.add_gsl_variables_as_non_scalar(diff_vars)
# add arrays we want to use in generated code before self.generator.translate() so
# brian does namespace unpacking for us
pointer_names = self.add_meta_variables(self.method_options)
scalar_statements = {}
vector_statements = {}
for ac_name, ac_code in code.items():
statements = make_statements(
ac_code, self.variables, dtype, optimise=True, blockname=ac_name
)
scalar_statements[ac_name], vector_statements[ac_name] = statements
for vs in vector_statements.values():
# Check that the statements are meaningful independent on the order of
# execution (e.g. for synapses)
try:
if self.has_repeated_indices(
vs
): # only do order dependence if there are repeated indices
check_for_order_independence(
vs, self.generator.variables, self.generator.variable_indices
)
except OrderDependenceError:
# If the abstract code is only one line, display it in full
if len(vs) <= 1:
error_msg = f"Abstract code: '{vs[0]}'\n"
else:
error_msg = (
f"{len(vs)} lines of abstract code, first line is: '{vs[0]}'\n"
)
logger.warn(
"Came across an abstract code block that may not be "
"well-defined: the outcome may depend on the "
"order of execution. You can ignore this warning if "
"you are sure that the order of operations does not "
"matter. " + error_msg
)
# save function names because self.generator.translate_statement_sequence
# deletes these from self.variables but we need to know which identifiers
# we can safely ignore (i.e. we can ignore the functions because they are
# handled by the original generator)
self.function_names = self.find_function_names()
scalar_code, vector_code, kwds = self.generator.translate_statement_sequence(
scalar_statements, vector_statements
)
############ translate code for GSL
# first check if any indexing other than '_idx' is used (currently not supported)
for code_list in list(scalar_code.values()) + list(vector_code.values()):
for code in code_list:
m = re.search(r"\[(\w+)\]", code)
if m is not None:
if m.group(1) != "0" and m.group(1) != "_idx":
from brian2.stateupdaters.base import (
UnsupportedEquationsException,
)
raise UnsupportedEquationsException(
"Equations result in state "
"updater code with indexing "
"other than '_idx', which "
"is currently not supported "
"in combination with the "
"GSL stateupdater."
)
# differential variable specific operations
to_replace = self.diff_var_to_replace(diff_vars)
GSL_support_code = self.get_dimension_code(len(diff_vars))
GSL_support_code += self.yvector_code(diff_vars)
# analyze all needed variables; if not in self.variables: put in separate dic.
# also keep track of variables needed for scalar statements and vector statements
other_variables = self.find_undefined_variables(
scalar_statements[None] + vector_statements[None]
)
variables_in_scalar = self.find_used_variables(
scalar_statements[None], other_variables
)
variables_in_vector = self.find_used_variables(
vector_statements[None], other_variables
)
# so that _dataholder holds diff_vars as well, even if they don't occur
# in the actual statements
for var in list(diff_vars.keys()):
if var not in variables_in_vector:
variables_in_vector[var] = self.variables[var]
# let's keep track of the variables that eventually need to be added to
# the _GSL_dataholder somehow
self.variables_to_be_processed = list(variables_in_vector.keys())
# add code for _dataholder struct
GSL_support_code = self.write_dataholder(variables_in_vector) + GSL_support_code
# add e.g. _lio_1 --> _GSL_dataholder._lio_1 to replacer
to_replace.update(
self.to_replace_vector_vars(
variables_in_vector, ignore=list(diff_vars.keys())
)
)
# write statements that unpack (python) namespace to _dataholder struct
# or local namespace
GSL_main_code = self.unpack_namespace(
variables_in_vector, variables_in_scalar, ["t"]
)
# rewrite actual calculations described by vector_code and put them in _GSL_func
func_code = self.translate_one_statement_sequence(
vector_statements[None], scalar=False
)
GSL_support_code += self.make_function_code(
self.translate_vector_code(func_code, to_replace)
)
scalar_func_code = self.translate_one_statement_sequence(
scalar_statements[None], scalar=True
)
# rewrite scalar code, keep variables that are needed in scalar code normal
# and add variables to _dataholder for vector_code
GSL_main_code += f"\n{self.translate_scalar_code(scalar_func_code, variables_in_scalar, variables_in_vector)}"
if len(self.variables_to_be_processed) > 0:
raise AssertionError(
"Not all variables that will be used in the vector "
"code have been added to the _GSL_dataholder. This "
"might mean that the _GSL_func is using "
"uninitialized variables.\n"
"The unprocessed variables "
f"are: {self.variables_to_be_processed}"
)
scalar_code["GSL"] = GSL_main_code
kwds["define_GSL_scale_array"] = self.scale_array_code(
diff_vars, self.method_options
)
kwds["n_diff_vars"] = len(diff_vars)
kwds["GSL_settings"] = dict(self.method_options)
kwds["GSL_settings"]["integrator"] = self.integrator
kwds["support_code_lines"] += GSL_support_code.split("\n")
kwds["t_array"] = f"{self.get_array_name(self.variables['t'])}[0]"
kwds["dt_array"] = f"{self.get_array_name(self.variables['dt'])}[0]"
kwds["define_dt"] = "dt" not in variables_in_scalar
kwds["cpp_standalone"] = self.is_cpp_standalone()
for key, value in list(pointer_names.items()):
kwds[key] = value
return scalar_code, vector_code, kwds
class GSLCythonCodeGenerator(GSLCodeGenerator):
    """
    GSL code generator that emits Cython source.

    Provides the Cython-specific syntax tokens and code snippets; the generic
    assembly logic lives in `GSLCodeGenerator`.
    """

    # Language-specific tokens (statement terminator, struct member access,
    # extern declarations, block openers/closers, cast brackets). Presumably
    # looked up by the base class when writing the GSL support code.
    syntax = {
        "end_statement": "",
        "access_pointer": ".",
        "start_declare": "cdef extern ",
        "open_function": ":",
        "open_struct": ":",
        "end_function": "",
        "end_struct": "",
        "open_cast": "<",
        "close_cast": ">",
        "diff_var_declaration": "",
    }

    def c_data_type(self, dtype):
        """Return the C type name for ``dtype`` via the module-level helper."""
        return c_data_type(dtype)

    def initialize_array(self, varname, values):
        """
        Return Cython code declaring ``varname`` as a ``double`` array that
        holds ``values``.

        Parameters
        ----------
        varname : str
            Name of the array in the generated code.
        values : sized iterable of numbers
            Initial values; ``len(values)`` gives the array size.

        Returns
        -------
        str
            Two lines: a ``cdef double`` declaration and a slice assignment.
        """
        # Convert to a built-in float before repr(). With NumPy >= 2, NumPy
        # scalars repr as e.g. 'np.float64(0.5)', which is not a valid literal
        # in the generated source. Plain Python floats are unaffected.
        value_list = ", ".join(repr(float(v)) for v in values)
        return (
            f"cdef double {varname}[{len(values)}]\n"
            f"{varname}[:] = [{value_list}]"
        )

    def var_replace_diff_var_lhs(self, var, ind):
        """
        Return a replacement mapping that rewrites the placeholder
        ``_gsl_{var}_f{ind}`` to ``f[{ind}]``.

        ``f`` is presumably the output (derivative) array of the generated
        GSL function; confirm against the base class.
        """
        return {f"_gsl_{var}_f{ind}": f"f[{ind}]"}

    def var_init_lhs(self, var, type):
        """Return the left-hand side for initialising ``var``.

        Cython needs no type prefix here, so ``type`` is ignored.
        """
        return var

    def unpack_namespace_single(self, var_obj, in_vector, in_scalar):
        """
        Return Cython statements that make one variable available to the
        generated code.

        Parameters
        ----------
        var_obj : Variable
            The variable to unpack.
        in_vector : bool
            If true, emit an assignment into the ``_GSL_dataholder`` struct.
        in_scalar : bool
            If true, emit an assignment to a local name of the same name.

        Returns
        -------
        str
            Newline-separated statements. The string is empty if neither flag
            is set.
        """
        code = []
        if isinstance(var_obj, ArrayVariable):
            array_name = self.generator.get_array_name(var_obj)
            dtype = self.c_data_type(var_obj.dtype)
            # Arrays are exposed as typed pointers cast from the buffer's data.
            if in_vector:
                code += [
                    f"_GSL_dataholder.{array_name} = <{dtype} *> _buf_{array_name}.data"
                ]
            if in_scalar:
                code += [f"{array_name} = <{dtype} *> _buf_{array_name}.data"]
        else:
            # Non-array values are read from the `_namespace` mapping.
            if in_vector:
                code += [
                    f'_GSL_dataholder.{var_obj.name} = _namespace["{var_obj.name}"]'
                ]
            if in_scalar:
                code += [f'{var_obj.name} = _namespace["{var_obj.name}"]']
        return "\n".join(code)

    @staticmethod
    def get_array_name(var, access_data=True):
        """Return the array name `CythonCodeGenerator` uses for ``var``."""
        # We have to do the import here to avoid circular import dependencies.
        from brian2.codegen.generators.cython_generator import CythonCodeGenerator

        return CythonCodeGenerator.get_array_name(var, access_data)
class GSLCPPCodeGenerator(GSLCodeGenerator):
    """
    GSL code generator that emits C++ source.

    Provides the C++-specific syntax tokens and code snippets. Attributes not
    found on this object are forwarded to the wrapped ``self.generator``.
    """

    # Language-specific tokens (statement terminator, pointer member access,
    # extern "C" declarations, braces, cast parentheses). Presumably looked up
    # by the base class when writing the GSL support code.
    syntax = {
        "end_statement": ";",
        "access_pointer": "->",
        "start_declare": 'extern "C" ',
        "open_function": "\n{",
        "open_struct": "\n{",
        "end_function": "\n}",
        "end_struct": "\n};",
        "open_cast": "(",
        "close_cast": ")",
        "diff_var_declaration": "const scalar ",
    }

    def __getattr__(self, item):
        """Forward failed attribute lookups to the wrapped generator.

        Python calls this only after normal attribute lookup has failed.
        """
        # Without this guard, a missing ``generator`` attribute makes the
        # ``self.generator`` lookup below re-enter __getattr__ and recurse
        # forever. That happens, for example, on instances created without
        # __init__ by copy/pickle when they probe for __setstate__. Raising
        # AttributeError keeps getattr()/hasattr() semantics intact.
        if item == "generator":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self.generator, item)

    def c_data_type(self, dtype):
        """Return the C type name for ``dtype`` as the wrapped generator defines it."""
        return self.generator.c_data_type(dtype)

    def initialize_array(self, varname, values):
        """
        Return a C++ declaration of a constant ``double`` array ``varname``
        initialised with ``values``.
        """
        # Convert to a built-in float before repr(). With NumPy >= 2, NumPy
        # scalars repr as e.g. 'np.float64(0.5)', which is not valid C++.
        # Plain Python floats are unaffected.
        value_list = ", ".join(repr(float(v)) for v in values)
        return f"double const {varname}[] = {{{value_list}}};"

    def var_replace_diff_var_lhs(self, var, ind):
        """
        Return a replacement mapping that redirects the placeholder
        ``_gsl_{var}_f{ind}`` to ``f[{ind}]``.

        For variables flagged ``"unless refractory"``, the mapping also
        removes the placeholder's separate declaration lines.
        """
        scalar_dtype = self.c_data_type(prefs.core.default_float_dtype)
        f = f"f[{ind}]"
        try:
            flags = self.variable_flags[var]
        except KeyError:
            flags = ()
        if "unless refractory" in flags:
            # Remove the declaration under both spellings, in case the
            # replacement of _gsl_{var}_f{ind} by f[{ind}] is applied first.
            return {
                f"_gsl_{var}_f{ind}": f,
                f"{scalar_dtype} _gsl_{var}_f{ind};": "",
                f"{scalar_dtype} {f};": "",
            }
        return {f"const {scalar_dtype} _gsl_{var}_f{ind}": f}

    def var_init_lhs(self, var, type):
        """Return the left-hand side for initialising ``var``.

        The result is the string prefix ``type`` followed by ``var``.
        """
        return type + var

    def unpack_namespace_single(self, var_obj, in_vector, in_scalar):
        """
        Return the C++ statement that copies one variable into the
        ``_GSL_dataholder`` struct.

        An empty string is returned when ``in_vector`` is false. ``in_scalar``
        is not used by the C++ target.
        """
        if isinstance(var_obj, ArrayVariable):
            # NOTE(review): get_array_name is resolved on the base class or
            # the wrapped generator, and both calls below may return the same
            # name. Kept as two calls to preserve the original behaviour.
            pointer_name = self.get_array_name(var_obj, access_data=True)
            array_name = self.get_array_name(var_obj)
            if in_vector:
                return f"_GSL_dataholder.{pointer_name} = {array_name};"
            return ""
        if in_vector:
            return f"_GSL_dataholder.{var_obj.name} = {var_obj.name};"
        return ""
|