diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 9a867e471..fd79745ba 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -30,6 +30,7 @@ AuxiliaryVariable, Constant, DynamicArrayVariable, + LinkedVariable, Subexpression, Variable, Variables, @@ -38,6 +39,7 @@ from brian2.importexport.importexport import ImportExport from brian2.units.fundamentalunits import ( DIMENSIONLESS, + DimensionMismatchError, fail_for_dimension_mismatch, get_unit, ) @@ -420,85 +422,217 @@ def __getattr__(self, name): except KeyError: raise AttributeError(f"No attribute with name {name}") - def __setattr__(self, name, val, level=0): + def __setattr__(self, key, value, level=0): # attribute access is switched off until this attribute is created by # _enable_group_attributes - if not hasattr(self, "_group_attribute_access_active") or name in self.__dict__: - object.__setattr__(self, name, val) - elif ( - name in self.__getattribute__("__dict__") - or name in self.__getattribute__("__class__").__dict__ - ): - # Makes sure that classes can override the "variables" mechanism - # with instance/class attributes and properties - return object.__setattr__(self, name, val) - elif name in self.variables: - var = self.variables[name] - if not isinstance(val, str): - if var.dim is DIMENSIONLESS: - fail_for_dimension_mismatch( - val, - var.dim, - "%s should be set with a dimensionless value, but got {value}" - % name, - value=val, - ) + if not hasattr(self, "_group_attribute_access_active") or key in self.__dict__: + object.__setattr__(self, key, value) + elif key in getattr(self, "_linked_variables", set()): + if not isinstance(value, LinkedVariable): + raise ValueError( + "Cannot set a linked variable directly, link " + "it to another variable using 'linked_var'." + ) + linked_var = value.variable + + if isinstance(linked_var, DynamicArrayVariable): + raise NotImplementedError( + f"Linking to variable {linked_var.name} is " + "not supported, can only link to " + "state variables of fixed size." + ) + + eq = self.equations[key] + if eq.dim is not linked_var.dim: + raise DimensionMismatchError( + f"Unit of variable '{key}' does not " + "match its link target " + f"'{linked_var.name}'" + ) + + if not isinstance(linked_var, Subexpression): + var_length = len(linked_var) + else: + var_length = len(linked_var.owner) + + if value.index is not None: + if isinstance(value.index, str): + if value.index not in self.variables: + raise ValueError(f"Index variable '{value.index}' not found.") + if ( + self.variables.indices[value.index] + != self.variables.default_index + ): + raise ValueError( + f"Index variable '{value.index}' should use the default index itself." + ) + if not np.issubdtype(self.variables[value.index].dtype, np.integer): + raise TypeError( + f"Index variable '{value.index}' should be an integer parameter." + ) + index = value.index else: - fail_for_dimension_mismatch( - val, - var.dim, - "%s should be set with a value with units %r, but got {value}" - % (name, get_unit(var.dim)), - value=val, + # Index arrays are not allowed for classes with dynamic size (Synapses) + if not isinstance(self.variables["N"], Constant): + raise TypeError( + "Cannot link a variable with an index array for a class with dynamic size – use a variable name instead." + ) + try: + index_array = np.asarray(value.index) + if not np.issubdtype(index_array.dtype, int): + raise TypeError() + except TypeError: + raise TypeError( + "The index for a linked variable has to be an integer array" + ) + size = len(index_array) + source_index = value.group.variables.indices[value.name] + if source_index not in ("_idx", "0"): + # we are indexing into an already indexed variable, + # calculate the indexing into the target variable + index_array = value.group.variables[source_index].get_value()[ + index_array + ] + + if not index_array.ndim == 1 or size != len(self): + raise TypeError( + f"Index array for linked variable '{key}' " + "has to be a one-dimensional array of " + f"length {len(self)}, but has shape " + f"{index_array.shape!s}" + ) + if min(index_array) < 0 or max(index_array) >= var_length: + raise ValueError( + f"Index array for linked variable {key} " + "contains values outside of the valid " + f"range [0, {var_length}[" + ) + self.variables.add_array( + f"_{key}_indices", + size=size, + dtype=index_array.dtype, + constant=True, + read_only=True, + values=index_array, ) - if var.read_only: - raise TypeError(f"Variable {name} is read-only.") - # Make the call X.var = ... equivalent to X.var[:] = ... - var.get_addressable_value_with_unit(name, self).set_item( - slice(None), val, level=level + 1 - ) - elif len(name) and name[-1] == "_" and name[:-1] in self.variables: - # no unit checking - var = self.variables[name[:-1]] - if var.read_only: - raise TypeError(f"Variable {name[:-1]} is read-only.") - # Make the call X.var = ... equivalent to X.var[:] = ... - var.get_addressable_value(name[:-1], self).set_item( - slice(None), val, level=level + 1 - ) - elif hasattr(self, name) or name.startswith("_"): - object.__setattr__(self, name, val) - else: - # Try to suggest the correct name in case of a typo - checker = SpellChecker( - [ - varname - for varname, var in self.variables.items() - if not (varname.startswith("_") or var.read_only) - ] - ) - if name.endswith("_"): - suffix = "_" - name = name[:-1] + index = f"_{key}_indices" else: - suffix = "" - error_msg = f'Could not find a state variable with name "{name}".' - suggestions = checker.suggest(name) - if len(suggestions) == 1: - (suggestion,) = suggestions - error_msg += f' Did you mean to write "{suggestion}{suffix}"?' - elif len(suggestions) > 1: - suggestion_str = ", ".join( - [f"'{suggestion}{suffix}'" for suggestion in suggestions] - ) - error_msg += ( - f" Did you mean to write any of the following: {suggestion_str} ?" + # The check at the end is to avoid the case that a size 1 NeuronGroup + # links to another NeuronGroup of size 1 and cannot do certain operations + # since the linked variable is considered scalar. + if linked_var.scalar or ( + var_length == 1 and getattr(self, "_N", 0) != 1 + ): + index = "0" + else: + index = value.group.variables.indices[value.name] + if index == "_idx": + target_length = var_length + else: + target_length = len(value.group.variables[index]) + # we need a name for the index that does not clash with + # other names and a reference to the index + new_index = f"_{value.name}_index_{index}" + self.variables.add_reference(new_index, value.group, index) + index = new_index + + if len(self) != target_length: + raise ValueError( + f"Cannot link variable '{key}' to " + f"'{linked_var.name}', the size of the " + "target group does not match " + f"({len(self)} != {target_length}). You can " + "provide an indexing scheme with the " + "'index' keyword to link groups with " + "different sizes" + ) + self.variables.add_reference(key, value.group, value.name, index=index) + source = (value.variable.owner.name,) + sourcevar = value.variable.name + log_msg = f"Setting {self.name}.{key} as a link to {source}.{sourcevar}" + if index is not None: + log_msg += f'(using "{index}" as index variable)' + logger.diagnostic(log_msg) + else: + if isinstance(value, LinkedVariable): + raise TypeError( + f"Cannot link variable '{key}', it has to be marked " + "as a linked variable with '(linked)' in the model " + "equations." ) - error_msg += ( - " Use the add_attribute method if you intend to add " - "a new attribute to the object." - ) - raise AttributeError(error_msg) + else: + if ( + key in self.__getattribute__("__dict__") + or key in self.__getattribute__("__class__").__dict__ + ): + # Makes sure that classes can override the "variables" mechanism + # with instance/class attributes and properties + return object.__setattr__(self, key, value) + elif key in self.variables: + var = self.variables[key] + if not isinstance(value, str): + if var.dim is DIMENSIONLESS: + fail_for_dimension_mismatch( + value, + var.dim, + "%s should be set with a dimensionless value, but got {value}" + % key, + value=value, + ) + else: + fail_for_dimension_mismatch( + value, + var.dim, + "%s should be set with a value with units %r, but got {value}" + % (key, get_unit(var.dim)), + value=value, + ) + if var.read_only: + raise TypeError(f"Variable {key} is read-only.") + # Make the call X.var = ... equivalent to X.var[:] = ... + var.get_addressable_value_with_unit(key, self).set_item( + slice(None), value, level=level + 1 + ) + elif len(key) and key[-1] == "_" and key[:-1] in self.variables: + # no unit checking + var = self.variables[key[:-1]] + if var.read_only: + raise TypeError(f"Variable {key[:-1]} is read-only.") + # Make the call X.var = ... equivalent to X.var[:] = ... + var.get_addressable_value(key[:-1], self).set_item( + slice(None), value, level=level + 1 + ) + elif hasattr(self, key) or key.startswith("_"): + object.__setattr__(self, key, value) + else: + # Try to suggest the correct name in case of a typo + checker = SpellChecker( + [ + varname + for varname, var in self.variables.items() + if not (varname.startswith("_") or var.read_only) + ] + ) + if key.endswith("_"): + suffix = "_" + key = key[:-1] + else: + suffix = "" + error_msg = f'Could not find a state variable with name "{key}".' + suggestions = checker.suggest(key) + if len(suggestions) == 1: + (suggestion,) = suggestions + error_msg += f' Did you mean to write "{suggestion}{suffix}"?' + elif len(suggestions) > 1: + suggestion_str = ", ".join( + [f"'{suggestion}{suffix}'" for suggestion in suggestions] + ) + error_msg += f" Did you mean to write any of the following: {suggestion_str} ?" + error_msg += ( + " Use the add_attribute method if you intend to add " + "a new attribute to the object." + ) + raise AttributeError(error_msg) def add_attribute(self, name): """ diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 4e3b27491..b8f3e7f6e 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -13,12 +13,7 @@ from brian2.codegen.translation import analyse_identifiers from brian2.core.preferences import prefs from brian2.core.spikesource import SpikeSource -from brian2.core.variables import ( - DynamicArrayVariable, - LinkedVariable, - Subexpression, - Variables, -) +from brian2.core.variables import Variables from brian2.equations.equations import ( DIFFERENTIAL_EQUATION, PARAMETER, @@ -36,7 +31,6 @@ from brian2.units.allunits import second from brian2.units.fundamentalunits import ( DIMENSIONLESS, - DimensionMismatchError, Quantity, fail_for_dimension_mismatch, ) @@ -785,122 +779,6 @@ def set_event_schedule(self, event, when="after_thresholds", order=None): self.thresholder[event].when = when self.thresholder[event].order = order - def __setattr__(self, key, value): - # attribute access is switched off until this attribute is created by - # _enable_group_attributes - if not hasattr(self, "_group_attribute_access_active") or key in self.__dict__: - object.__setattr__(self, key, value) - elif key in self._linked_variables: - if not isinstance(value, LinkedVariable): - raise ValueError( - "Cannot set a linked variable directly, link " - "it to another variable using 'linked_var'." - ) - linked_var = value.variable - - if isinstance(linked_var, DynamicArrayVariable): - raise NotImplementedError( - f"Linking to variable {linked_var.name} is " - "not supported, can only link to " - "state variables of fixed size." - ) - - eq = self.equations[key] - if eq.dim is not linked_var.dim: - raise DimensionMismatchError( - f"Unit of variable '{key}' does not " - "match its link target " - f"'{linked_var.name}'" - ) - - if not isinstance(linked_var, Subexpression): - var_length = len(linked_var) - else: - var_length = len(linked_var.owner) - - if value.index is not None: - try: - index_array = np.asarray(value.index) - if not np.issubdtype(index_array.dtype, int): - raise TypeError() - except TypeError: - raise TypeError( - "The index for a linked variable has to be an integer array" - ) - size = len(index_array) - source_index = value.group.variables.indices[value.name] - if source_index not in ("_idx", "0"): - # we are indexing into an already indexed variable, - # calculate the indexing into the target variable - index_array = value.group.variables[source_index].get_value()[ - index_array - ] - - if not index_array.ndim == 1 or size != len(self): - raise TypeError( - f"Index array for linked variable '{key}' " - "has to be a one-dimensional array of " - f"length {len(self)}, but has shape " - f"{index_array.shape!s}" - ) - if min(index_array) < 0 or max(index_array) >= var_length: - raise ValueError( - f"Index array for linked variable {key} " - "contains values outside of the valid " - f"range [0, {var_length}[" - ) - self.variables.add_array( - f"_{key}_indices", - size=size, - dtype=index_array.dtype, - constant=True, - read_only=True, - values=index_array, - ) - index = f"_{key}_indices" - else: - if linked_var.scalar or (var_length == 1 and self._N != 1): - index = "0" - else: - index = value.group.variables.indices[value.name] - if index == "_idx": - target_length = var_length - else: - target_length = len(value.group.variables[index]) - # we need a name for the index that does not clash with - # other names and a reference to the index - new_index = f"_{value.name}_index_{index}" - self.variables.add_reference(new_index, value.group, index) - index = new_index - - if len(self) != target_length: - raise ValueError( - f"Cannot link variable '{key}' to " - f"'{linked_var.name}', the size of the " - "target group does not match " - f"({len(self)} != {target_length}). You can " - "provide an indexing scheme with the " - "'index' keyword to link groups with " - "different sizes" - ) - - self.variables.add_reference(key, value.group, value.name, index=index) - source = (value.variable.owner.name,) - sourcevar = value.variable.name - log_msg = f"Setting {self.name}.{key} as a link to {source}.{sourcevar}" - if index is not None: - log_msg += f'(using "{index}" as index variable)' - logger.diagnostic(log_msg) - else: - if isinstance(value, LinkedVariable): - raise TypeError( - f"Cannot link variable '{key}', it has to be marked " - "as a linked variable with '(linked)' in the model " - "equations." - ) - else: - Group.__setattr__(self, key, value, level=1) - def __getitem__(self, item): start, stop = to_start_stop(item, self._N) diff --git a/brian2/synapses/synapses.py b/brian2/synapses/synapses.py index 0ff46b8bb..57b9e33bd 100644 --- a/brian2/synapses/synapses.py +++ b/brian2/synapses/synapses.py @@ -854,7 +854,7 @@ def __init__( { DIFFERENTIAL_EQUATION: ["event-driven", "clock-driven"], SUBEXPRESSION: ["summed", "shared", "constant over dt"], - PARAMETER: ["constant", "shared"], + PARAMETER: ["constant", "shared", "linked"], }, incompatible_flags=[ ("event-driven", "clock-driven"), @@ -944,6 +944,8 @@ def __init__( else: self.event_driven = None + self._linked_variables = set() + self._create_variables(model, user_dtype=dtype) self.equations = Equations(continuous) @@ -1372,7 +1374,10 @@ def _create_variables(self, equations, user_dtype=None): check_identifier_pre_post(eq.varname) constant = "constant" in eq.flags shared = "shared" in eq.flags - if shared: + linked = "linked" in eq.flags + if linked: + self._linked_variables.add(eq.varname) + elif shared: self.variables.add_array( eq.varname, size=1, diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index 28b094f11..6796a14d3 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -462,6 +462,30 @@ def test_linked_variable_indexed(): assert_allclose(G.y[:], np.arange(10)[::-1] * 0.1) +@pytest.mark.codegen_independent +def test_linked_variable_index_variable(): + """ + Test linking a variable with an index specified as an array + """ + G = NeuronGroup( + 10, + """ + x : 1 + index_var : integer + not_an_index_var : 1 + y : 1 (linked) + """, + ) + + G.x = np.arange(10) * 0.1 + with pytest.raises(TypeError): + G.y = linked_var(G.x, index="not_an_index_var") + G.y = linked_var(G.x, index="index_var") + G.index_var = np.arange(10)[::-1] + # G.y should refer to an inverted version of G.x + assert_allclose(G.y[:], np.arange(10)[::-1] * 0.1) + + @pytest.mark.codegen_independent def test_linked_variable_repeat(): """ diff --git a/brian2/tests/test_synapses.py b/brian2/tests/test_synapses.py index 7e00063f9..4e4813693 100644 --- a/brian2/tests/test_synapses.py +++ b/brian2/tests/test_synapses.py @@ -1626,6 +1626,63 @@ def test_summed_variables_linked_variables(): net.run(0 * ms) +@pytest.mark.codegen_independent +def test_linked_to_shared_variables(): + source1 = NeuronGroup(1, "dx/dt = -x / (10*ms) : 1") + source1.x = "rand()" + source2 = NeuronGroup(10, "x : 1 (shared)") + source2.x = "rand()" + mon = StateMonitor(source1, "x", record=True) + group = NeuronGroup(10, "") + syn = Synapses( + group, + group, + """x1 : 1 (linked) + x2 : 1 (linked) + """, + ) + syn.x1 = linked_var(source1.x) + syn.x2 = linked_var(source2.x) + syn.connect(i=[0, 5], j=[3, 6]) + mon_syn = StateMonitor(syn, ["x1", "x2"], record=True) + run(2 * defaultclock.dt) + assert_allclose(mon.x[0], mon_syn.x1[0]) + assert_allclose(mon.x[0], mon_syn.x1[1]) + assert_allclose(source2.x, mon_syn.x2[0]) + assert_allclose(source2.x, mon_syn.x2[1]) + + +@pytest.mark.codegen_independent +def test_linked_to_non_shared_variables(): + source = NeuronGroup(10, "dx/dt = -x / (10*ms) : 1") + source.x = "rand()" + mon = StateMonitor(source, "x", record=True) + group = NeuronGroup(10, "y: 1") + syn = Synapses( + group, + group, + """x : 1 (linked) + x_ind: integer + not_an_index : 1 + """, + ) + syn.connect(i=[0, 5], j=[3, 6]) + with pytest.raises(TypeError): + syn.x = linked_var(source.x, index=[0, 1]) + with pytest.raises(ValueError): + syn.x = linked_var(source.x, index="does_not_exist") + with pytest.raises(ValueError): + syn.x = linked_var(source.x, index="y_post") + with pytest.raises(TypeError): + syn.x = linked_var(source.x, index="not_an_index") + syn.x = linked_var(source.x, index="x_ind") + syn.x_ind = [3, 5] + mon_syn = StateMonitor(syn, "x", record=True) + run(2 * defaultclock.dt) + assert_allclose(mon.x[3], mon_syn.x[0]) + assert_allclose(mon.x[5], mon_syn.x[1]) + + def test_scalar_parameter_access(): G = NeuronGroup( 10,