Skip to content

Commit

Permalink
Linked variables with string indices
Browse files Browse the repository at this point in the history
Enable linked variables for Synapses in certain cases
  • Loading branch information
mstimberg committed Jan 29, 2025
1 parent f1c0e3e commit 5fc3856
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 198 deletions.
280 changes: 207 additions & 73 deletions brian2/groups/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AuxiliaryVariable,
Constant,
DynamicArrayVariable,
LinkedVariable,
Subexpression,
Variable,
Variables,
Expand All @@ -38,6 +39,7 @@
from brian2.importexport.importexport import ImportExport
from brian2.units.fundamentalunits import (
DIMENSIONLESS,
DimensionMismatchError,
fail_for_dimension_mismatch,
get_unit,
)
Expand Down Expand Up @@ -420,85 +422,217 @@ def __getattr__(self, name):
except KeyError:
raise AttributeError(f"No attribute with name {name}")

def __setattr__(self, name, val, level=0):
def __setattr__(self, key, value, level=0):
# attribute access is switched off until this attribute is created by
# _enable_group_attributes
if not hasattr(self, "_group_attribute_access_active") or name in self.__dict__:
object.__setattr__(self, name, val)
elif (
name in self.__getattribute__("__dict__")
or name in self.__getattribute__("__class__").__dict__
):
# Makes sure that classes can override the "variables" mechanism
# with instance/class attributes and properties
return object.__setattr__(self, name, val)
elif name in self.variables:
var = self.variables[name]
if not isinstance(val, str):
if var.dim is DIMENSIONLESS:
fail_for_dimension_mismatch(
val,
var.dim,
"%s should be set with a dimensionless value, but got {value}"
% name,
value=val,
)
if not hasattr(self, "_group_attribute_access_active") or key in self.__dict__:
object.__setattr__(self, key, value)
elif key in getattr(self, "_linked_variables", set()):
if not isinstance(value, LinkedVariable):
raise ValueError(
"Cannot set a linked variable directly, link "
"it to another variable using 'linked_var'."
)
linked_var = value.variable

if isinstance(linked_var, DynamicArrayVariable):
raise NotImplementedError(
f"Linking to variable {linked_var.name} is "
"not supported, can only link to "
"state variables of fixed size."
)

eq = self.equations[key]
if eq.dim is not linked_var.dim:
raise DimensionMismatchError(
f"Unit of variable '{key}' does not "
"match its link target "
f"'{linked_var.name}'"
)

if not isinstance(linked_var, Subexpression):
var_length = len(linked_var)
else:
var_length = len(linked_var.owner)

if value.index is not None:
if isinstance(value.index, str):
if value.index not in self.variables:
raise ValueError(f"Index variable '{value.index}' not found.")
if (
self.variables.indices[value.index]
!= self.variables.default_index
):
raise ValueError(
f"Index variable '{value.index}' should use the default index itself."
)
if not np.issubdtype(self.variables[value.index].dtype, np.integer):
raise TypeError(
f"Index variable '{value.index}' should be an integer parameter."
)
index = value.index
else:
fail_for_dimension_mismatch(
val,
var.dim,
"%s should be set with a value with units %r, but got {value}"
% (name, get_unit(var.dim)),
value=val,
# Index arrays are not allowed for classes with dynamic size (Synapses)
if not isinstance(self.variables["N"], Constant):
raise TypeError(
"Cannot link a variable with an index array for a class with dynamic size – use a variable name instead."
)
try:
index_array = np.asarray(value.index)
if not np.issubdtype(index_array.dtype, int):
raise TypeError()
except TypeError:
raise TypeError(
"The index for a linked variable has to be an integer array"
)
size = len(index_array)
source_index = value.group.variables.indices[value.name]
if source_index not in ("_idx", "0"):
# we are indexing into an already indexed variable,
# calculate the indexing into the target variable
index_array = value.group.variables[source_index].get_value()[
index_array
]

if not index_array.ndim == 1 or size != len(self):
raise TypeError(
f"Index array for linked variable '{key}' "
"has to be a one-dimensional array of "
f"length {len(self)}, but has shape "
f"{index_array.shape!s}"
)
if min(index_array) < 0 or max(index_array) >= var_length:
raise ValueError(
f"Index array for linked variable {key} "
"contains values outside of the valid "
f"range [0, {var_length}["
)
self.variables.add_array(
f"_{key}_indices",
size=size,
dtype=index_array.dtype,
constant=True,
read_only=True,
values=index_array,
)
if var.read_only:
raise TypeError(f"Variable {name} is read-only.")
# Make the call X.var = ... equivalent to X.var[:] = ...
var.get_addressable_value_with_unit(name, self).set_item(
slice(None), val, level=level + 1
)
elif len(name) and name[-1] == "_" and name[:-1] in self.variables:
# no unit checking
var = self.variables[name[:-1]]
if var.read_only:
raise TypeError(f"Variable {name[:-1]} is read-only.")
# Make the call X.var = ... equivalent to X.var[:] = ...
var.get_addressable_value(name[:-1], self).set_item(
slice(None), val, level=level + 1
)
elif hasattr(self, name) or name.startswith("_"):
object.__setattr__(self, name, val)
else:
# Try to suggest the correct name in case of a typo
checker = SpellChecker(
[
varname
for varname, var in self.variables.items()
if not (varname.startswith("_") or var.read_only)
]
)
if name.endswith("_"):
suffix = "_"
name = name[:-1]
index = f"_{key}_indices"
else:
suffix = ""
error_msg = f'Could not find a state variable with name "{name}".'
suggestions = checker.suggest(name)
if len(suggestions) == 1:
(suggestion,) = suggestions
error_msg += f' Did you mean to write "{suggestion}{suffix}"?'
elif len(suggestions) > 1:
suggestion_str = ", ".join(
[f"'{suggestion}{suffix}'" for suggestion in suggestions]
)
error_msg += (
f" Did you mean to write any of the following: {suggestion_str} ?"
# The check at the end is to avoid the case that a size 1 NeuronGroup
# links to another NeuronGroup of size 1 and cannot do certain operations
# since the linked variable is considered scalar.
if linked_var.scalar or (
var_length == 1 and getattr(self, "_N", 0) != 1
):
index = "0"
else:
index = value.group.variables.indices[value.name]
if index == "_idx":
target_length = var_length
else:
target_length = len(value.group.variables[index])
# we need a name for the index that does not clash with
# other names and a reference to the index
new_index = f"_{value.name}_index_{index}"
self.variables.add_reference(new_index, value.group, index)
index = new_index

if len(self) != target_length:
raise ValueError(
f"Cannot link variable '{key}' to "
f"'{linked_var.name}', the size of the "
"target group does not match "
f"({len(self)} != {target_length}). You can "
"provide an indexing scheme with the "
"'index' keyword to link groups with "
"different sizes"
)
self.variables.add_reference(key, value.group, value.name, index=index)
source = (value.variable.owner.name,)
sourcevar = value.variable.name
log_msg = f"Setting {self.name}.{key} as a link to {source}.{sourcevar}"
if index is not None:
log_msg += f'(using "{index}" as index variable)'
logger.diagnostic(log_msg)
else:
if isinstance(value, LinkedVariable):
raise TypeError(
f"Cannot link variable '{key}', it has to be marked "
"as a linked variable with '(linked)' in the model "
"equations."
)
error_msg += (
" Use the add_attribute method if you intend to add "
"a new attribute to the object."
)
raise AttributeError(error_msg)
else:
if (
key in self.__getattribute__("__dict__")
or key in self.__getattribute__("__class__").__dict__
):
# Makes sure that classes can override the "variables" mechanism
# with instance/class attributes and properties
return object.__setattr__(self, key, value)
elif key in self.variables:
var = self.variables[key]
if not isinstance(value, str):
if var.dim is DIMENSIONLESS:
fail_for_dimension_mismatch(
value,
var.dim,
"%s should be set with a dimensionless value, but got {value}"
% key,
value=value,
)
else:
fail_for_dimension_mismatch(
value,
var.dim,
"%s should be set with a value with units %r, but got {value}"
% (key, get_unit(var.dim)),
value=value,
)
if var.read_only:
raise TypeError(f"Variable {key} is read-only.")
# Make the call X.var = ... equivalent to X.var[:] = ...
var.get_addressable_value_with_unit(key, self).set_item(
slice(None), value, level=level + 1
)
elif len(key) and key[-1] == "_" and key[:-1] in self.variables:
# no unit checking
var = self.variables[key[:-1]]
if var.read_only:
raise TypeError(f"Variable {key[:-1]} is read-only.")
# Make the call X.var = ... equivalent to X.var[:] = ...
var.get_addressable_value(key[:-1], self).set_item(
slice(None), value, level=level + 1
)
elif hasattr(self, key) or key.startswith("_"):
object.__setattr__(self, key, value)
else:
# Try to suggest the correct name in case of a typo
checker = SpellChecker(
[
varname
for varname, var in self.variables.items()
if not (varname.startswith("_") or var.read_only)
]
)
if key.endswith("_"):
suffix = "_"
key = key[:-1]
else:
suffix = ""
error_msg = f'Could not find a state variable with name "{key}".'
suggestions = checker.suggest(key)
if len(suggestions) == 1:
(suggestion,) = suggestions
error_msg += f' Did you mean to write "{suggestion}{suffix}"?'
elif len(suggestions) > 1:
suggestion_str = ", ".join(
[f"'{suggestion}{suffix}'" for suggestion in suggestions]
)
error_msg += f" Did you mean to write any of the following: {suggestion_str} ?"
error_msg += (
" Use the add_attribute method if you intend to add "
"a new attribute to the object."
)
raise AttributeError(error_msg)

def add_attribute(self, name):
"""
Expand Down
Loading

0 comments on commit 5fc3856

Please sign in to comment.