From 244daf62cb2d7849c5f7c086adac4fdc0f0554a7 Mon Sep 17 00:00:00 2001 From: Hafez Ghaemi Date: Tue, 18 Jan 2022 21:11:24 +0100 Subject: [PATCH 1/4] Add local connection classes and MNIST examples --- bindsnet/analysis/plotting.py | 69 ++- bindsnet/learning/learning.py | 1093 ++++++++++++++++++++++++++++++++- bindsnet/network/topology.py | 484 ++++++++++++++- bindsnet/utils.py | 63 ++ examples/mnist/conv_mnist.py | 4 +- examples/mnist/loc1d_mnist.py | 170 +++++ examples/mnist/loc2d_mnist.py | 182 ++++++ examples/mnist/loc3d_mnist.py | 172 ++++++ 8 files changed, 2229 insertions(+), 8 deletions(-) create mode 100644 examples/mnist/loc1d_mnist.py create mode 100644 examples/mnist/loc2d_mnist.py create mode 100644 examples/mnist/loc3d_mnist.py diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index d7871274..b6909ee4 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -9,7 +9,7 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable from torch.nn.modules.utils import _pair -from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights +from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights, reshape_local_connection_2d_weights plt.ion() @@ -377,6 +377,71 @@ def plot_locally_connected_weights( return im +def plot_local_connection_2d_weights(lc : object, + input_channel: int = 0, + output_channel: int = None, + im: Optional[AxesImage] = None, + lines: bool = True, + figsize: Tuple[int, int] = (5, 5), + cmap: str = "hot_r", + color: str='r', + ) -> AxesImage: + # language=rst + """ + Plot a connection weight matrix of a :code:`Connection` with `locally connected + structure _. + :param lc: An object of the class LocalConnection2D + :param input_channel: The input channel to plot its corresponding weights, default is the first channel + :param output_channel: If not None, will only plot the weights corresponding to this output channel (filter) + :param lines: Indicates whether or not draw horizontal and vertical lines separating input regions. + :param figsize: Horizontal and vertical figure size in inches. + :param cmap: Matplotlib colormap. + :return: ``ims, axes``: Used for re-drawing the plots. + """ + + n_sqrt = int(np.ceil(np.sqrt(lc.n_filters))) + sel_slice = lc.w.view(lc.in_channels, lc.n_filters, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu() + input_size = _pair(int(np.sqrt(lc.source.n))) + if output_channel is None: + sel_slice = sel_slice[input_channel, ...] + reshaped = reshape_local_connection_2d_weights(sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size) + else: + sel_slice = sel_slice[input_channel, output_channel, ...] + sel_slice = sel_slice.unsqueeze(0) + reshaped = reshape_local_connection_2d_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size) + if im == None: + fig, ax = plt.subplots(figsize=figsize) + + im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax) + div = make_axes_locatable(ax) + cax = div.append_axes("right", size="5%", pad=0.05) + + if lines and output_channel is None: + for i in range( + n_sqrt * lc.kernel_size[0], + n_sqrt * lc.conv_size[0] * lc.kernel_size[0], + n_sqrt * lc.kernel_size[0], + ): + ax.axhline(i - 0.5, color=color, linestyle="--") + + for i in range( + n_sqrt * lc.kernel_size[1], + n_sqrt * lc.conv_size[1] * lc.kernel_size[1], + n_sqrt * lc.kernel_size[1], + ): + ax.axvline(i - 0.5, color=color, linestyle="--") + + ax.set_xticks(()) + ax.set_yticks(()) + ax.set_aspect("auto") + + plt.colorbar(im, cax=cax) + fig.tight_layout() + + else: + im.set_data(reshaped.cpu()) + return im + def plot_assignments( assignments: torch.Tensor, @@ -760,3 +825,5 @@ def plot_voltages( plt.tight_layout() return ims, axes + + diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 6c7725f1..fa3e15f6 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -16,6 +16,9 @@ Conv2dConnection, Conv3dConnection, LocalConnection, + LocalConnection1D, + LocalConnection2D, + LocalConnection3D, ) @@ -176,6 +179,12 @@ def __init__( if isinstance(connection, (Connection, LocalConnection)): self.update = self._connection_update + elif isinstance(connection, LocalConnection1D): + self.update = self._local_connection1d_update + elif isinstance(connection, LocalConnection2D): + self.update = self._local_connection2d_update + elif isinstance(connection, LocalConnection3D): + self.update = self._local_connection3d_update elif isinstance(connection, Conv1dConnection): self.update = self._conv1d_connection_update elif isinstance(connection, Conv2dConnection): @@ -186,6 +195,163 @@ def __init__( raise NotImplementedError( "This learning rule is not supported for this Connection type." ) + + def _local_connection1d_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``LocalConnection1D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size + + + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) + target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + + def _local_connection2d_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``LocalConnection2D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + + def _local_connection3d_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``LocalConnection3D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + kernel_depth = self.connection.kernel_size[2] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + depth_out = self.connection.conv_size[2] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() def _connection_update(self, **kwargs) -> None: # language=rst @@ -337,6 +503,7 @@ def _conv3d_connection_update(self, **kwargs) -> None: ) ) target_s = self.target.s.view(batch_size, out_channels, -1).float() + print(target_x.shape, source_s.shape, self.connection.w.shape) # Pre-synaptic update. if self.nu[0]: @@ -396,12 +563,18 @@ def __init__( if isinstance(connection, (Connection, LocalConnection)): self.update = self._connection_update + elif isinstance(connection, LocalConnection1D): + self.update = self._local_connection1d_update + elif isinstance(connection, LocalConnection2D): + self.update = self._local_connection2d_update + elif isinstance(connection, LocalConnection3D): + self.update = self._local_connection3d_update elif isinstance(connection, Conv1dConnection): self.update = self._conv1d_connection_update elif isinstance(connection, Conv2dConnection): self.update = self._conv2d_connection_update - elif isinstance(connection, Conv1dConnection): - self.update = self._conv1d_connection_update + elif isinstance(connection, Conv3dConnection): + self.update = self._conv3d_connection_update else: raise NotImplementedError( "This learning rule is not supported for this Connection type." @@ -436,6 +609,175 @@ def _connection_update(self, **kwargs) -> None: super().update() + def _local_connection1d_update(self, **kwargs) -> None: + # language=rst + """ + Weight-dependent post-pre learning rule for ``LocalConnection1D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size + + + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) + target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + update = 0 + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + + self.connection.w += update + + super().update() + + def _local_connection2d_update(self, **kwargs) -> None: + # language=rst + """ + Weight-dependent post-pre learning rule for ``LocalConnection2D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + update = 0 + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + + self.connection.w += update + + super().update() + + def _local_connection3d_update(self, **kwargs) -> None: + # language=rst + """ + Weight-dependent post-pre learning rule for ``LocalConnection3D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + kernel_depth = self.connection.kernel_size[2] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + depth_out = self.connection.conv_size[2] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + update = 0 + + # Pre-synaptic update. + if self.nu[0]: + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + # Post-synaptic update. + if self.nu[1]: + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + + self.connection.w += update + + super().update() + def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -487,6 +829,8 @@ def _conv1d_connection_update(self, **kwargs) -> None: super().update() + + def _conv2d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -658,6 +1002,12 @@ def __init__( if isinstance(connection, (Connection, LocalConnection)): self.update = self._connection_update + elif isinstance(connection, LocalConnection1D): + self.update = self._local_connection1d_update + elif isinstance(connection, LocalConnection2D): + self.update = self._local_connection2d_update + elif isinstance(connection, LocalConnection3D): + self.update = self._local_connection3d_update elif isinstance(connection, Conv1dConnection): self.update = self._conv1d_connection_update elif isinstance(connection, Conv2dConnection): @@ -692,6 +1042,160 @@ def _connection_update(self, **kwargs) -> None: super().update() + def _local_connection1d_update(self, **kwargs) -> None: + # language=rst + """ + Hebbian learning rule for ``LocalConnection1D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size + + + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) + target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) + + # Post-synaptic update. + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + + def _local_connection2d_update(self, **kwargs) -> None: + # language=rst + """ + Hebbian learning rule for ``LocalConnection2D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) + + # Post-synaptic update. + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + + def _local_connection3d_update(self, **kwargs) -> None: + # language=rst + """ + Post-pre learning rule for ``LocalConnection3D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + kernel_depth = self.connection.kernel_size[2] + in_channels = self.connection.source.shape[0] + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + depth_out = self.connection.conv_size[2] + + target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) + target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # Pre-synaptic update. + pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) + + # Post-synaptic update. + post = self.reduction(torch.bmm(target_s, source_x),dim=0) + self.connection.w += self.nu[1] * post.view(self.connection.w.size()) + + super().update() + def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -865,6 +1369,12 @@ def __init__( self.update = self._conv2d_connection_update elif isinstance(connection, Conv3dConnection): self.update = self._conv3d_connection_update + elif isinstance(connection, LocalConnection1D): + self.update = self._local_connection1d_update + elif isinstance(connection, LocalConnection2D): + self.update = self._local_connection2d_update + elif isinstance(connection, LocalConnection3D): + self.update = self._local_connection3d_update else: raise NotImplementedError( "This learning rule is not supported for this Connection type." @@ -937,6 +1447,280 @@ def _connection_update(self, **kwargs) -> None: super().update() + def _local_connection1d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``LocalConnection1D`` subclass of + ``AbstractConnection`` class. + """ + + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Compute weight update based on the eligibility value of the past timestep. + update = reward * self.eligibility + self.connection.w += self.nu[0] * self.reduction(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-1, kernel_height,stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels*kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) + target_s = target_s * torch.eye(out_channels*height_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + + + def _local_connection2d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``LocalConnection2D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Compute weight update based on the eligibility value of the past timestep. + update = reward * self.eligibility + + self.connection.w += self.nu[0] * self.reduction(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out * width_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-2, kernel_height, stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels*kernel_height*kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out, 1) + target_s = target_s * torch.eye(out_channels*height_out*width_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + + + def _local_connection3d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDP learning rule for ``LocalConnection3D`` subclass of + ``AbstractConnection`` class. + """ + + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size[0] + kernel_width = self.connection.kernel_size[1] + kernel_depth = self.connection.kernel_size[2] + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + depth_out = self.connection.conv_size[2] + + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Compute weight update based on the eligibility value of the past timestep. + update = reward * self.eligibility + self.connection.w += self.nu[0] * self.reduction(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out * width_out * depth_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-3, kernel_height, stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels*kernel_height*kernel_width*kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) + target_s = target_s * torch.eye(out_channels*height_out*width_out*depth_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + + def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -1236,6 +2020,12 @@ def __init__( if isinstance(connection, (Connection, LocalConnection)): self.update = self._connection_update + elif isinstance(connection, LocalConnection1D): + self.update = self._local_connection1d_update + elif isinstance(connection, LocalConnection2D): + self.update = self._local_connection2d_update + elif isinstance(connection, LocalConnection3D): + self.update = self._local_connection3d_update elif isinstance(connection, Conv1dConnection): self.update = self._conv1d_connection_update elif isinstance(connection, Conv2dConnection): @@ -1314,6 +2104,305 @@ def _connection_update(self, **kwargs) -> None: super().update() + def _local_connection1d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDPET learning rule for ``LocalConnection1D`` subclass of + ``AbstractConnection`` class. + """ + + # Get LC layer parameters. + + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_height = self.connection.kernel_size + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + if not hasattr(self, "eligibility_trace"): + self.eligibility_trace = torch.zeros( + *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Calculate value of eligibility trace based on the value + # of the point eligibility value of the past timestep. + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) + + # Compute weight update. + update = reward * self.eligibility_trace + self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-1, kernel_height, stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-1, kernel_height,stride).reshape( + batch_size, + height_out, + in_channels * kernel_height, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # print(target_x.shape, source_s.shape) + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out,1) + target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + + def _local_connection2d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDPET learning rule for ``LocalConnection2D`` subclass of + ``AbstractConnection`` class. + """ + # Get LC layer parameters. + + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_width = self.connection.kernel_size[0] + kernel_height = self.connection.kernel_size[1] + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + + + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + if not hasattr(self, "eligibility_trace"): + self.eligibility_trace = torch.zeros( + *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Calculate value of eligibility trace based on the value + # of the point eligibility value of the past timestep. + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) + + # Compute weight update. + update = reward * self.eligibility_trace + self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-2, kernel_height, stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out * width_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( + batch_size, + height_out*width_out, + in_channels * kernel_height * kernel_width, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # print(target_x.shape, source_s.shape) + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out * width_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + + def _local_connection3d_update(self, **kwargs) -> None: + # language=rst + """ + MSTDPET learning rule for ``LocalConnection3D`` subclass of + ``AbstractConnection`` class. + """ + + # Get LC layer parameters. + stride = self.connection.stride + batch_size = self.source.batch_size + kernel_width = self.connection.kernel_size[0] + kernel_height = self.connection.kernel_size[1] + kernel_depth = self.connection.kernel_size[2] + in_channels = self.connection.in_channels + out_channels = self.connection.n_filters + height_out = self.connection.conv_size[0] + width_out = self.connection.conv_size[1] + depth_out = self.connection.conv_size[2] + + # Initialize eligibility. + if not hasattr(self, "eligibility"): + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) + + if not hasattr(self, "eligibility_trace"): + self.eligibility_trace = torch.zeros( + *self.connection.w.shape, device=self.connection.w.device + ) + + # Parse keyword arguments. + reward = kwargs["reward"] + a_plus = torch.tensor( + kwargs.get("a_plus", 1.0), device=self.connection.w.device + ) + a_minus = torch.tensor( + kwargs.get("a_minus", -1.0), device=self.connection.w.device + ) + + # Calculate value of eligibility trace based on the value + # of the point eligibility value of the past timestep. + self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) + + # Compute weight update. + update = reward * self.eligibility_trace + self.connection.w += self.nu[0] * self.connection.dt * torch.sum(update, dim=0) + + # Initialize P^+ and P^-. + if not hasattr(self, "p_plus"): + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.connection.w.device + ) + self.p_plus = self.p_plus.unfold(-3, kernel_height, stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width*kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape(batch_size,\ + out_channels * height_out * width_out * depth_out, 1) + self.p_minus = self.p_minus *\ + torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( + -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( + batch_size, + height_out*width_out*depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ).repeat( + 1, + out_channels, + 1, + ).to(self.connection.w.device) + + # print(target_x.shape, source_s.shape) + target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out * width_out * depth_out,1) + target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + + # Update P^+ and P^- values. + self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) + self.p_plus += a_plus * source_s + self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus) + self.p_minus += a_minus * target_s + + # Calculate point eligibility value. + self.eligibility = torch.bmm( + target_s, self.p_plus + ) + torch.bmm(self.p_minus, source_s) + self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) + + super().update() + def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 24b46fda..ff3f57b6 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.nn.functional as F +from bindsnet.utils import im2col_indices from torch.nn import Module, Parameter from torch.nn.modules.utils import _pair, _triple @@ -252,7 +253,6 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - class Conv1dConnection(AbstractConnection): # language=rst """ @@ -398,7 +398,6 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - class Conv2dConnection(AbstractConnection): # language=rst """ @@ -1013,8 +1012,8 @@ def __init__( ) -> None: # language=rst """ - Instantiates a ``LocalConnection`` object. Source population should be - two-dimensional. + Instantiates a ``LocalConnection2D`` object. Source population should have + square size Neurons in the post-synaptic population are ordered by receptive field; that is, if there are ``n_conv`` neurons in each post-synaptic patch, then the first @@ -1045,6 +1044,7 @@ def __init__( :param Tuple[int, int] input_shape: Shape of input population if it's not ``[sqrt, sqrt]``. """ + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) kernel_size = _pair(kernel_size) @@ -1169,7 +1169,483 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() +class LocalConnection1D(AbstractConnection): + """ + Specifies a one-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H); + The logic is different from the original LocalConnection implementation (where masks were used with normal dense connections). + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + kernel_size: int, + stride: int, + n_filters: int, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + weight_decay: float = 0.0, + **kwargs + ) -> None: + """ + Instantiates a 'LocalConnection1D` object. Source population can be multi-channel. + Neurons in the post-synaptic population are ordered by receptive field, i.e., + if there are `n_conv` neurons in each post-synaptic patch, then the first + `n_conv` neurons in the post-synaptic population correspond to the first + receptive field, the second ``n_conv`` to the second receptive field, and so on. + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param kernel_size: size of convolutional kernels. + :param stride: stride for convolution. + :param n_filters: Number of locally connected filters per pre-synaptic region. + :param nu: Learning rate for both pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the minibatch dimension. + :param weight_decay: Constant multiple to decay weights by on each iteration. + Keyword arguments: + :param LearningRule update_rule: Modifies connection parameters according to some rule. + :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor b: Target population bias. + :param float wmin: Minimum allowed value on the connection weights. + :param float wmax: Maximum allowed value on the connection weights. + :param float norm: Total weight per target neuron normalization constant. + """ + + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) + + + self.kernel_size = kernel_size + self.stride = stride + self.n_filters = n_filters + + self.in_channels, input_height = ( + source.shape[0], + source.shape[1], + ) + + height = int(( + input_height - self.kernel_size + ) / self.stride) + 1 + + + self.conv_size = height + + w = kwargs.get("w", None) + + error = ( + "Target dimensionality must be (in_channels," + "n_filters*conv_size," + "kernel_size)" + ) + + if w is None: + w = torch.rand( + self.in_channels, + self.n_filters * self.conv_size, + self.kernel_size + ) + else: + assert w.shape == ( + self.in_channels, + self.out_channels * self.conv_size, + self.kernel_size + ), error + + if self.wmin != -np.inf or self.wmax != np.inf: + w = torch.clamp(w, self.wmin, self.wmax) + + self.w = Parameter(w, requires_grad=False) + self.b = Parameter(kwargs.get("b", None), requires_grad=False) + + + def compute(self, s: torch.Tensor) -> torch.Tensor: + """ + Compute pre-activations given spikes using layer weights. + :param s: Incoming spikes. + :return: Incoming spikes multiplied by synaptic weights (with or without + decaying spike activation). + """ + # Compute multiplication of pre-activations by connection weights + # s: batch, ch_in, h_in => s_unfold: batch, ch_in, ch_out * h_out, k + # w: ch_in, ch_out * h_out, k + # a_post: batch, ch_in, ch_out * h_out, k => batch, ch_out * h_out (= target.n) + + batch_size = s.shape[0] + + self.s_unfold = s.unfold( + -1,self.kernel_size,self.stride + ).reshape( + s.shape[0], + self.in_channels, + self.conv_size, + self.kernel_size, + ).repeat( + 1, + 1, + self.n_filters, + 1, + ) + + a_post = self.s_unfold.to(self.w.device) * self.w + + return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) + + def update(self, **kwargs) -> None: + """ + Compute connection's update rule. + """ + super().update(**kwargs) + + def normalize(self) -> None: + """ + Normalize weights so each target neuron has sum of connection weights equal to + ``self.norm``. + """ + if self.norm is not None: + # get a view and modify in-place + # w: ch_in, ch_out * h_out, k + w = self.w.view( + self.w.shape[0]*self.w.shape[1], self.w.shape[2] + ) + + for fltr in range(w.shape[0]): + w[fltr,:] *= self.norm / w[fltr,:].sum(0) + + + def reset_state_variables(self) -> None: + """ + Contains resetting logic for the connection. + """ + super().reset_state_variables() + + self.target.reset_state_variables() + +class LocalConnection2D(AbstractConnection): + """ + Specifies a two-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H, W); + The logic is different from the original LocalConnection implementation (where masks were used with normal dense connections) + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + n_filters: int, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + weight_decay: float = 0.0, + **kwargs + ) -> None: + """ + Instantiates a 'LocalConnection2D` object. Source population can be multi-channel. + Neurons in the post-synaptic population are ordered by receptive field, i.e., + if there are `n_conv` neurons in each post-synaptic patch, then the first + `n_conv` neurons in the post-synaptic population correspond to the first + receptive field, the second ``n_conv`` to the second receptive field, and so on. + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param kernel_size: Horizontal and vertical size of convolutional kernels. + :param stride: Horizontal and vertical stride for convolution. + :param n_filters: Number of locally connected filters per pre-synaptic region. + :param nu: Learning rate for both pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the minibatch dimension. + :param weight_decay: Constant multiple to decay weights by on each iteration. + Keyword arguments: + :param LearningRule update_rule: Modifies connection parameters according to some rule. + :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor b: Target population bias. + :param float wmin: Minimum allowed value on the connection weights. + :param float wmax: Maximum allowed value on the connection weights. + :param float norm: Total weight per target neuron normalization constant. + """ + + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) + + kernel_size = _pair(kernel_size) + stride = _pair(stride) + + self.kernel_size = kernel_size + self.stride = stride + self.n_filters = n_filters + + self.in_channels, input_height, input_width = ( + source.shape[0], + source.shape[1], + source.shape[2], + ) + + height = int(( + input_height - self.kernel_size[0] + ) / self.stride[0]) + 1 + width = int(( + input_width - self.kernel_size[1] + ) / self.stride[1]) + 1 + + + self.conv_size = (height, width) + self.conv_prod = int(np.prod(self.conv_size)) + self.kernel_prod = int(np.prod(kernel_size)) + + w = kwargs.get("w", None) + + error = ( + "Target dimensionality must be (in_channels," + "n_filters*conv_prod," + "kernel_prod)" + ) + + if w is None: + w = torch.rand( + self.in_channels, + self.n_filters * self.conv_prod, + self.kernel_prod + ) + else: + assert w.shape == ( + self.in_channels, + self.out_channels * self.conv_prod, + self.kernel_prod + ), error + + if self.wmin != -np.inf or self.wmax != np.inf: + w = torch.clamp(w, self.wmin, self.wmax) + + self.w = Parameter(w, requires_grad=False) + self.b = Parameter(kwargs.get("b", None), requires_grad=False) + + + def compute(self, s: torch.Tensor) -> torch.Tensor: + """ + Compute pre-activations given spikes using layer weights. + :param s: Incoming spikes. + :return: Incoming spikes multiplied by synaptic weights (with or without + decaying spike activation). + """ + # Compute multiplication of pre-activations by connection weights + # s: batch, ch_in, w_in, h_in => s_unfold: batch, ch_in, ch_out * w_out * h_out, k1*k2 + # w: ch_in, ch_out * w_out * h_out, k1*k2 + # a_post: batch, ch_in, ch_out * w_out * h_out, k1*k2 => batch, ch_out * w_out * h_out (= target.n) + + batch_size = s.shape[0] + + self.s_unfold = s.unfold( + -2,self.kernel_size[0],self.stride[0] + ).unfold( + -2,self.kernel_size[1],self.stride[1] + ).reshape( + s.shape[0], + self.in_channels, + self.conv_prod, + self.kernel_prod, + ).repeat( + 1, + 1, + self.n_filters, + 1, + ) + + a_post = self.s_unfold.to(self.w.device) * self.w + + return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) + + def update(self, **kwargs) -> None: + """ + Compute connection's update rule. + """ + super().update(**kwargs) + + def normalize(self) -> None: + """ + Normalize weights so each target neuron has sum of connection weights equal to + ``self.norm``. + """ + if self.norm is not None: + # get a view and modify in-place + # w: ch_in, ch_out * w_out * h_out, k1 * k2 + w = self.w.view( + self.w.shape[0]*self.w.shape[1], self.w.shape[2] + ) + + for fltr in range(w.shape[0]): + w[fltr,:] *= self.norm / w[fltr,:].sum(0) + + + def reset_state_variables(self) -> None: + """ + Contains resetting logic for the connection. + """ + super().reset_state_variables() + + self.target.reset_state_variables() + + +class LocalConnection3D(AbstractConnection): + """ + Specifies a three-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H, W, D); + The logic is different from the original LocalConnection implementation (where masks were used with normal dense connections) + """ + + def __init__( + self, + source: Nodes, + target: Nodes, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]], + n_filters: int, + nu: Optional[Union[float, Sequence[float]]] = None, + reduction: Optional[callable] = None, + weight_decay: float = 0.0, + **kwargs + ) -> None: + """ + Instantiates a 'LocalConnection3D` object. Source population can be multi-channel. + Neurons in the post-synaptic population are ordered by receptive field, i.e., + if there are `n_conv` neurons in each post-synaptic patch, then the first + `n_conv` neurons in the post-synaptic population correspond to the first + receptive field, the second ``n_conv`` to the second receptive field, and so on. + :param source: A layer of nodes from which the connection originates. + :param target: A layer of nodes to which the connection connects. + :param kernel_size: Horizontal, vertical, and depth-wise size of convolutional kernels. + :param stride: Horizontal, vertical, and depth-wise stride for convolution. + :param n_filters: Number of locally connected filters per pre-synaptic region. + :param nu: Learning rate for both pre- and post-synaptic events. + :param reduction: Method for reducing parameter updates along the minibatch dimension. + :param weight_decay: Constant multiple to decay weights by on each iteration. + Keyword arguments: + :param LearningRule update_rule: Modifies connection parameters according to some rule. + :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor b: Target population bias. + :param float wmin: Minimum allowed value on the connection weights. + :param float wmax: Maximum allowed value on the connection weights. + :param float norm: Total weight per target neuron normalization constant. + """ + + super().__init__(source, target, nu, reduction, weight_decay, **kwargs) + + kernel_size = _triple(kernel_size) + stride = _triple(stride) + + self.kernel_size = kernel_size + self.stride = stride + self.n_filters = n_filters + + self.in_channels, input_height, input_width, input_depth = ( + source.shape[0], + source.shape[1], + source.shape[2], + source.shape[3] + ) + + height = int(( + input_height - self.kernel_size[0] + ) / self.stride[0]) + 1 + width = int(( + input_width - self.kernel_size[1] + ) / self.stride[1]) + 1 + depth = int(( + input_depth - self.kernel_size[2] + ) / self.stride[2]) + 1 + + + self.conv_size = (height, width, depth) + self.conv_prod = int(np.prod(self.conv_size)) + self.kernel_prod = int(np.prod(kernel_size)) + + w = kwargs.get("w", None) + + error = ( + "Target dimensionality must be (in_channels," + "n_filters*conv_prod," + "kernel_prod)" + ) + + if w is None: + w = torch.rand( + self.in_channels, + self.n_filters * self.conv_prod, + self.kernel_prod + ) + else: + assert w.shape == ( + self.in_channels, + self.out_channels * self.conv_prod, + self.kernel_prod + ), error + + if self.wmin != -np.inf or self.wmax != np.inf: + w = torch.clamp(w, self.wmin, self.wmax) + + self.w = Parameter(w, requires_grad=False) + self.b = Parameter(kwargs.get("b", None), requires_grad=False) + + + def compute(self, s: torch.Tensor) -> torch.Tensor: + """ + Compute pre-activations given spikes using layer weights. + :param s: Incoming spikes. + :return: Incoming spikes multiplied by synaptic weights (with or without + decaying spike activation). + """ + # Compute multiplication of pre-activations by connection weights + # s: batch, ch_in, w_in, h_in, d_in => s_unfold: batch, ch_in, ch_out * w_out * h_out * d_out, k1*k2*k3 + # w: ch_in, ch_out * w_out * h_out * d_out, k1*k2*k3 + # a_post: batch, ch_in, ch_out * w_out * h_out * d_out, k1*k2*k3 => batch, ch_out * w_out * h_out * d_out (= target.n) + + batch_size = s.shape[0] + + self.s_unfold = s.unfold( + -3,self.kernel_size[0],self.stride[0] + ).unfold( + -3,self.kernel_size[1],self.stride[1] + ).unfold( + -3,self.kernel_size[2],self.stride[2] + ).reshape( + s.shape[0], + self.in_channels, + self.conv_prod, + self.kernel_prod, + ).repeat( + 1, + 1, + self.n_filters, + 1, + ) + + a_post = self.s_unfold.to(self.w.device) * self.w + + return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) + + def update(self, **kwargs) -> None: + """ + Compute connection's update rule. + """ + super().update(**kwargs) + + def normalize(self) -> None: + """ + Normalize weights so each target neuron has sum of connection weights equal to + ``self.norm``. + """ + if self.norm is not None: + # get a view and modify in-place + # w: ch_in, ch_out * w_out * h_out * d_out, k1*k2*k3 + w = self.w.view( + self.w.shape[0]*self.w.shape[1], self.w.shape[2] + ) + + for fltr in range(w.shape[0]): + w[fltr,:] *= self.norm / w[fltr,:].sum(0) + + + def reset_state_variables(self) -> None: + """ + Contains resetting logic for the connection. + """ + super().reset_state_variables() + self.target.reset_state_variables() + class MeanFieldConnection(AbstractConnection): # language=rst """ diff --git a/bindsnet/utils.py b/bindsnet/utils.py index 05863662..628bf17d 100644 --- a/bindsnet/utils.py +++ b/bindsnet/utils.py @@ -214,3 +214,66 @@ def reshape_conv2d_weights(weights: torch.Tensor) -> torch.Tensor: ] = fltr return reshaped + + +def reshape_local_connection_2d_weights( + w: torch.Tensor, + n_filters: int, + kernel_size: Union[int, Tuple[int, int]], + conv_size: Union[int, Tuple[int, int]], + input_sqrt: Union[int, Tuple[int, int]], +) -> torch.Tensor: + # language=rst + """ + Reshape a slice of weights of a LocalConnection2D slice for plotting. + :param w: Slice of weights from a LocalConnection2D object. + :param n_filters: Number of filters (output channels). + :param kernel_size: Side length(s) of convolutional kernel. + :param conv_size: Side length(s) of convolution population. + :param input_sqrt: Sides length(s) of input neurons. + :return: A slice of LocalConnection2D weights reshaped as a collection of spatially ordered square grids. + """ + + k1, k2 = kernel_size + c1, c2 = conv_size + i1, i2 = input_sqrt + + fs = int(np.ceil(np.sqrt(n_filters))) + + w_ = torch.zeros((n_filters * k1, k2 * c1 * c2)) + + for n1 in range(c1): + for n2 in range(c2): + for feature in range(n_filters): + n = n1 * c2 + n2 + filter_ = w[feature, n1, n2, :, : + ].view(k1, k2) + w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_ + + if c1 == 1 and c2 == 1: + square = torch.zeros((i1 * fs, i2 * fs)) + + for n in range(n_filters): + square[ + (n // fs) * i1 : ((n // fs) + 1) * i2, + (n % fs) * i2 : ((n % fs) + 1) * i2, + ] = w_[n * i1 : (n + 1) * i2] + + return square + else: + square = torch.zeros((k1 * fs * c1, k2 * fs * c2)) + + for n1 in range(c1): + for n2 in range(c2): + for f1 in range(fs): + for f2 in range(fs): + if f1 * fs + f2 < n_filters: + square[ + k1 * (n1 * fs + f1) : k1 * (n1 * fs + f1 + 1), + k2 * (n2 * fs + f2) : k2 * (n2 * fs + f2 + 1), + ] = w_[ + (f1 * fs + f2) * k1 : (f1 * fs + f2 + 1) * k1, + (n1 * c2 + n2) * k2 : (n1 * c2 + n2 + 1) * k2, + ] + + return square \ No newline at end of file diff --git a/examples/mnist/conv_mnist.py b/examples/mnist/conv_mnist.py index a23f0b4b..227148cc 100644 --- a/examples/mnist/conv_mnist.py +++ b/examples/mnist/conv_mnist.py @@ -28,7 +28,7 @@ parser.add_argument("--n_epochs", type=int, default=1) parser.add_argument("--n_test", type=int, default=10000) parser.add_argument("--n_train", type=int, default=60000) -parser.add_argument("--batch_size", type=int, default=1) +parser.add_argument("--batch_size", type=int, default=10) parser.add_argument("--kernel_size", type=int, default=16) parser.add_argument("--stride", type=int, default=4) parser.add_argument("--n_filters", type=int, default=25) @@ -160,6 +160,8 @@ voltage_ims = None voltage_axes = None +plot = False + for epoch in range(n_epochs): if epoch % progress_interval == 0: print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) diff --git a/examples/mnist/loc1d_mnist.py b/examples/mnist/loc1d_mnist.py new file mode 100644 index 00000000..c96cb0cc --- /dev/null +++ b/examples/mnist/loc1d_mnist.py @@ -0,0 +1,170 @@ +### Toy example to test LocanConnection1D (the dataset used is MNIST but each image is raveled (each sample has shape (784,)). + +import torch +from torch.nn.modules.utils import _pair + + +from tqdm import tqdm +import os +from bindsnet.network.monitors import Monitor + +import torch +from torchvision import transforms +from tqdm import tqdm + + +from time import time as t +from torchvision import transforms +from bindsnet.learning import PostPre + +from bindsnet.network.nodes import AdaptiveLIFNodes +from bindsnet.network.nodes import Input +from bindsnet.network.network import Network +from bindsnet.network.topology import Connection, LocalConnection1D +from bindsnet.encoding import PoissonEncoder +from bindsnet.datasets import MNIST + + +# Hyperparameters +in_channels = 1 +n_filters = 25 +input_shape = 784 +kernel_size = 28*2 +stride = 28 +tc_theta_decay = 1e6 +theta_plus = 0.05 +norm = 0.2*kernel_size +wmin = 0.0 +wmax = 1.0 +nu = (1e-4, 1e-2) +inh = 25.0 +dt = 1.0 +time = 250 +intensity = 128 +n_epochs = 1 +n_train = 500 +progress_interval = 10 +batch_size = 1 + +# Build network +network = Network() + +input_layer = Input( + shape=[in_channels, input_shape], + traces=True, + tc_trace=20 +) + +compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 +conv_size = compute_conv_size(input_shape, kernel_size, stride) + +output_layer = AdaptiveLIFNodes( + shape=[n_filters, conv_size], + traces=True, + rest=-65.0, + reset=-60.0, + thresh=-52.0, + refrac=5, + tc_decay=100.0, + tc_trace=20.0, + theta_plus=theta_plus, + tc_theta_decay=tc_theta_decay, +) + +input_output_conn = LocalConnection1D( + input_layer, + output_layer, + kernel_size=kernel_size, + stride=stride, + n_filters = n_filters, + nu=nu, + update_rule=PostPre, + wmin=wmin, + wmax=wmax, + norm=norm, +) + +w_inh_LC = torch.zeros(n_filters, conv_size, n_filters, conv_size) +for c in range(n_filters): + for w1 in range(conv_size): + w_inh_LC[c, w1, :, w1] = -inh + w_inh_LC[c, w1, c, w1] = 0 + +w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) +recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) + +network.add_layer(input_layer, name="X") +network.add_layer(output_layer, name="Y") +network.add_connection(input_output_conn, source="X", target="Y") +network.add_connection(recurrent_conn, source="Y", target="Y") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +gpu = True +seed = 0 + +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False + +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + +if gpu: + network.to("cuda") + +# Load MNIST data. +train_dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + "../../data/MNIST", + download=True, + train=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)] + ), +) + +spikes = {} +for layer in set(network.layers): + spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) + network.add_monitor(spikes[layer], name="%s_spikes" % layer) + +voltages = {} +for layer in set(network.layers) - {"X"}: + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) + network.add_monitor(voltages[layer], name="%s_voltages" % layer) + +# Train the network. +print("Begin training.\n") +start = t() + +for epoch in range(n_epochs): + if epoch % progress_interval == 0: + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) + start = t() + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + ) + + for step, batch in enumerate(tqdm(train_dataloader)): + # Get next input sample. + if step > n_train: + break + inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, 28 * 28)} + if gpu: + inputs = {k: v.cuda() for k, v in inputs.items()} + label = batch["label"] + + # Run the network on the input. + network.run(inputs=inputs, time=time, input_time_dim=1) + + + network.reset_state_variables() # Reset state variables. + +print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) +print("Training complete.\n") diff --git a/examples/mnist/loc2d_mnist.py b/examples/mnist/loc2d_mnist.py new file mode 100644 index 00000000..7acd29f8 --- /dev/null +++ b/examples/mnist/loc2d_mnist.py @@ -0,0 +1,182 @@ + +import torch +from torch.nn.modules.utils import _pair + + +from tqdm import tqdm +import os +from bindsnet.network.monitors import Monitor +import matplotlib.pyplot as plt +import torch +from torchvision import transforms +from tqdm import tqdm + + +from bindsnet.analysis.plotting import plot_local_connection_2d_weights + +from time import time as t +from torchvision import transforms +from bindsnet.learning import PostPre + +from bindsnet.network.nodes import AdaptiveLIFNodes +from bindsnet.network.nodes import Input +from bindsnet.network.network import Network +from bindsnet.network.topology import Connection, LocalConnection2D +from bindsnet.encoding import PoissonEncoder +from bindsnet.datasets import MNIST + + +# Hyperparameters +in_channels = 1 +n_filters = 50 +input_shape = [20, 20] +kernel_size = _pair(12) +stride = _pair(4) +tc_theta_decay = 1e6 +theta_plus = 0.05 +norm = 0.2*kernel_size[0]*kernel_size[1] +wmin = 0.0 +wmax = 1.0 +nu = (0.0001,0.01) +inh = 25.0 +dt = 1.0 +time = 250 +intensity = 128 +n_epochs = 1 +n_train = 2500 +progress_interval = 10 +batch_size = 1 + +plot = True + +# Build network +network = Network() + +input_layer = Input( + shape=[in_channels, input_shape[0], input_shape[1]], + traces=True, + tc_trace=20 +) + +compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 +conv_size = _pair(compute_conv_size(input_shape[0], kernel_size[0], stride[0])) + +output_layer = AdaptiveLIFNodes( + shape=[n_filters, conv_size[0], conv_size[1]], + traces=True, + rest=-65.0, + reset=-60.0, + thresh=-52.0, + refrac=5, + tc_trace=20.0, + theta_plus=theta_plus, + tc_theta_decay=tc_theta_decay, +) + +input_output_conn = LocalConnection2D( + input_layer, + output_layer, + kernel_size=kernel_size, + stride=stride, + n_filters = n_filters, + nu=nu, + update_rule=PostPre, + wmin=wmin, + wmax=wmax, + norm=norm, +) + +w_inh_LC = torch.zeros(n_filters, conv_size[0], conv_size[1], n_filters, conv_size[0], conv_size[1]) +for c in range(n_filters): + for w1 in range(conv_size[0]): + for w2 in range(conv_size[0]): + w_inh_LC[c, w1, w2, :, w1, w2] = -inh + w_inh_LC[c, w1, w2, c, w1, w2] = 0 + +w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) +recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) + +network.add_layer(input_layer, name="X") +network.add_layer(output_layer, name="Y") +network.add_connection(input_output_conn, source="X", target="Y") +network.add_connection(recurrent_conn, source="Y", target="Y") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +gpu = True +seed = 0 +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False + +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + +if gpu: + network.to("cuda") + +# Load MNIST data. +train_dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + "../../data/MNIST", + download=True, + train=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.CenterCrop((input_shape[0], input_shape[1])), transforms.Lambda(lambda x: x * intensity)] + ), +) + +spikes = {} +for layer in set(network.layers): + spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) + network.add_monitor(spikes[layer], name="%s_spikes" % layer) + +voltages = {} +for layer in set(network.layers) - {"X"}: + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) + network.add_monitor(voltages[layer], name="%s_voltages" % layer) + +# Train the network. +print("Begin training.\n") +start = t() + +weights1_im = None + +for epoch in range(n_epochs): + if epoch % progress_interval == 0: + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) + start = t() + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + ) + + for step, batch in enumerate(tqdm(train_dataloader)): + # Get next input sample. + if step > n_train: + break + inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, input_shape[0], input_shape[1])} + if gpu: + inputs = {k: v.cuda() for k, v in inputs.items()} + label = batch["label"] + + # Run the network on the input. + network.run(inputs=inputs, time=time, input_time_dim=1) + + # Optionally plot various simulation information. + if plot: + weights1_im = plot_local_connection_2d_weights(network.connections[("X", "Y")], im=weights1_im) + plt.pause(1) + + network.reset_state_variables() # Reset state variables. + +print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) +print("Training complete.\n") + +weights1_im = plot_local_connection_2d_weights(network.connections[("X", "Y")]) +plt.savefig('test.png') +plt.pause(100) \ No newline at end of file diff --git a/examples/mnist/loc3d_mnist.py b/examples/mnist/loc3d_mnist.py new file mode 100644 index 00000000..4cf6c18b --- /dev/null +++ b/examples/mnist/loc3d_mnist.py @@ -0,0 +1,172 @@ +### Toy example to test LocalConnection3D (the dataset used is MNIST but with a dimension replicated +### for each image (each sample has size (28, 28, 28)) + +import torch +from torch.nn.modules.utils import _triple + + +from tqdm import tqdm +import os +from bindsnet.network.monitors import Monitor + +import torch +from torchvision import transforms +from tqdm import tqdm + + +from time import time as t +from torchvision import transforms +from bindsnet.learning import PostPre + +from bindsnet.network.nodes import AdaptiveLIFNodes +from bindsnet.network.nodes import Input +from bindsnet.network.network import Network +from bindsnet.network.topology import Connection, LocalConnection3D +from bindsnet.encoding import PoissonEncoder +from bindsnet.datasets import MNIST + + +# Hyperparameters +in_channels = 1 +n_filters = 25 +input_shape = [20, 20, 20] +kernel_size = _triple(16) +stride = _triple(2) +tc_theta_decay = 1e6 +theta_plus = 0.05 +norm = 0.2*kernel_size[0]*kernel_size[1]*kernel_size[2] +wmin = 0.0 +wmax = 1.0 +nu = (0.0001,0.01) +inh = 25.0 +dt = 1.0 +time = 250 +intensity = 128 +n_epochs = 1 +n_train = 2500 +progress_interval = 10 +batch_size = 1 + +# Build network +network = Network() + +input_layer = Input(n=input_shape[0]*input_shape[1]*input_shape[2], shape=(in_channels, input_shape[0], input_shape[1], input_shape[2]), traces=True) + +compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 +conv_size = _triple(compute_conv_size(input_shape[0], kernel_size[0], stride[0])) + +output_layer = AdaptiveLIFNodes( + shape=[n_filters, conv_size[0], conv_size[1], conv_size[2]], + traces=True, + rest=-65.0, + reset=-60.0, + thresh=-52.0, + refrac=5, + tc_trace=20.0, + theta_plus=theta_plus, + tc_theta_decay=tc_theta_decay, +) + +input_output_conn = LocalConnection3D( + input_layer, + output_layer, + kernel_size=kernel_size, + stride=stride, + n_filters = n_filters, + nu=nu, + update_rule=PostPre, + wmin=wmin, + wmax=wmax, + norm=norm, +) + +w_inh_LC = torch.zeros(n_filters, conv_size[0], conv_size[1], conv_size[2], n_filters, conv_size[0], conv_size[1], conv_size[2]) + +for c in range(n_filters): + for w1 in range(conv_size[0]): + for w2 in range(conv_size[1]): + for w3 in range(conv_size[2]): + w_inh_LC[c, w1, w2, w3, :, w1, w2, w3] = -inh + w_inh_LC[c, w1, w2, w3, c, w1, w2, w3] = 0 + +w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) +recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) + +network.add_layer(input_layer, name="X") +network.add_layer(output_layer, name="Y") +network.add_connection(input_output_conn, source="X", target="Y") +network.add_connection(recurrent_conn, source="Y", target="Y") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +gpu = True +seed = 0 +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False + +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + +if gpu: + network.to("cuda") + +# Load MNIST data. +train_dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + "../../data/MNIST", + download=True, + train=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.CenterCrop((input_shape[0], input_shape[1])), transforms.Lambda(lambda x: x * intensity)] + ), +) + +spikes = {} +for layer in set(network.layers): + spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) + network.add_monitor(spikes[layer], name="%s_spikes" % layer) + +voltages = {} +for layer in set(network.layers) - {"X"}: + voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) + network.add_monitor(voltages[layer], name="%s_voltages" % layer) + +# Train the network. +print("Begin training.\n") +start = t() + +for epoch in range(n_epochs): + if epoch % progress_interval == 0: + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) + start = t() + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + ) + + for step, batch in enumerate(tqdm(train_dataloader)): + # Get next input sample. + if step > n_train: + break + inputs = { + "X": batch["encoded_image"] + .view(time, batch_size, 1, input_shape[0], input_shape[1]) + .unsqueeze(3) + .repeat(1, 1, 1, input_shape[2], 1, 1) + .float() + } + if gpu: + inputs = {k: v.cuda() for k, v in inputs.items()} + label = batch["label"] + + # Run the network on the input. + network.run(inputs=inputs, time=time, input_time_dim=1) + + +print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) +print("Training complete.\n") \ No newline at end of file From 134a983f88caf8802770b931d22287394a9e0925 Mon Sep 17 00:00:00 2001 From: Hafez Ghaemi Date: Tue, 18 Jan 2022 21:40:43 +0100 Subject: [PATCH 2/4] fix typo fix typo --- bindsnet/learning/learning.py | 2 +- examples/mnist/conv_mnist.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index fa3e15f6..5096c7c9 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -1199,7 +1199,7 @@ def _local_connection3d_update(self, **kwargs) -> None: def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ - Hebbian learning rule for ``Conv2dConnection`` subclass of + Hebbian learning rule for ``Conv1dConnection`` subclass of ``AbstractConnection`` class. """ out_channels, in_channels, kernel_size = self.connection.w.size() diff --git a/examples/mnist/conv_mnist.py b/examples/mnist/conv_mnist.py index 227148cc..a23f0b4b 100644 --- a/examples/mnist/conv_mnist.py +++ b/examples/mnist/conv_mnist.py @@ -28,7 +28,7 @@ parser.add_argument("--n_epochs", type=int, default=1) parser.add_argument("--n_test", type=int, default=10000) parser.add_argument("--n_train", type=int, default=60000) -parser.add_argument("--batch_size", type=int, default=10) +parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--kernel_size", type=int, default=16) parser.add_argument("--stride", type=int, default=4) parser.add_argument("--n_filters", type=int, default=25) @@ -160,8 +160,6 @@ voltage_ims = None voltage_axes = None -plot = False - for epoch in range(n_epochs): if epoch % progress_interval == 0: print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) From 6a3b80e5e69195fd57cc565c1e26bf8d08b447ce Mon Sep 17 00:00:00 2001 From: Hafez Ghaemi Date: Thu, 3 Feb 2022 19:07:46 +0100 Subject: [PATCH 3/4] Fix the maxPool batch size issue --- bindsnet/network/topology.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index ff3f57b6..33f6c79b 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -813,7 +813,7 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.s.shape) + self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) class MaxPool2dConnection(AbstractConnection): @@ -901,7 +901,7 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.s.shape) + self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) class MaxPoo3dConnection(AbstractConnection): @@ -989,7 +989,7 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.s.shape) + self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) class LocalConnection(AbstractConnection): From 9049bc63f9a61f184bbb50121e77a8a20d162041 Mon Sep 17 00:00:00 2001 From: Hananel Hazan Date: Sun, 13 Feb 2022 17:34:29 -0500 Subject: [PATCH 4/4] New local connection classes (1D, 2D, and 3D) Apply fix for #537 --- bindsnet/analysis/plotting.py | 39 +- bindsnet/datasets/spoken_mnist.py | 2 +- bindsnet/learning/learning.py | 1128 +++++++++++++++++---------- bindsnet/network/nodes.py | 10 +- bindsnet/network/topology.py | 214 +++-- bindsnet/utils.py | 5 +- examples/mnist/conv3d_MNIST.py | 2 +- examples/mnist/conv_mnist.py | 2 +- examples/mnist/loc1d_mnist.py | 23 +- examples/mnist/loc2d_mnist.py | 41 +- examples/mnist/loc3d_mnist.py | 39 +- examples/tensorboard/tensorboard.py | 2 +- 12 files changed, 915 insertions(+), 592 deletions(-) diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index b6909ee4..89007b08 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -9,7 +9,11 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable from torch.nn.modules.utils import _pair -from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights, reshape_local_connection_2d_weights +from bindsnet.utils import ( + reshape_conv2d_weights, + reshape_locally_connected_weights, + reshape_local_connection_2d_weights, +) plt.ion() @@ -377,15 +381,17 @@ def plot_locally_connected_weights( return im -def plot_local_connection_2d_weights(lc : object, + +def plot_local_connection_2d_weights( + lc: object, input_channel: int = 0, output_channel: int = None, im: Optional[AxesImage] = None, lines: bool = True, figsize: Tuple[int, int] = (5, 5), cmap: str = "hot_r", - color: str='r', - ) -> AxesImage: + color: str = "r", +) -> AxesImage: # language=rst """ Plot a connection weight matrix of a :code:`Connection` with `locally connected @@ -400,15 +406,26 @@ def plot_local_connection_2d_weights(lc : object, """ n_sqrt = int(np.ceil(np.sqrt(lc.n_filters))) - sel_slice = lc.w.view(lc.in_channels, lc.n_filters, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu() + sel_slice = lc.w.view( + lc.in_channels, + lc.n_filters, + lc.conv_size[0], + lc.conv_size[1], + lc.kernel_size[0], + lc.kernel_size[1], + ).cpu() input_size = _pair(int(np.sqrt(lc.source.n))) if output_channel is None: sel_slice = sel_slice[input_channel, ...] - reshaped = reshape_local_connection_2d_weights(sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size) + reshaped = reshape_local_connection_2d_weights( + sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size + ) else: sel_slice = sel_slice[input_channel, output_channel, ...] sel_slice = sel_slice.unsqueeze(0) - reshaped = reshape_local_connection_2d_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size) + reshaped = reshape_local_connection_2d_weights( + sel_slice, 1, lc.kernel_size, lc.conv_size, input_size + ) if im == None: fig, ax = plt.subplots(figsize=figsize) @@ -416,7 +433,7 @@ def plot_local_connection_2d_weights(lc : object, div = make_axes_locatable(ax) cax = div.append_axes("right", size="5%", pad=0.05) - if lines and output_channel is None: + if lines and output_channel is None: for i in range( n_sqrt * lc.kernel_size[0], n_sqrt * lc.conv_size[0] * lc.kernel_size[0], @@ -430,7 +447,7 @@ def plot_local_connection_2d_weights(lc : object, n_sqrt * lc.kernel_size[1], ): ax.axvline(i - 0.5, color=color, linestyle="--") - + ax.set_xticks(()) ax.set_yticks(()) ax.set_aspect("auto") @@ -441,7 +458,7 @@ def plot_local_connection_2d_weights(lc : object, else: im.set_data(reshaped.cpu()) return im - + def plot_assignments( assignments: torch.Tensor, @@ -825,5 +842,3 @@ def plot_voltages( plt.tight_layout() return ims, axes - - diff --git a/bindsnet/datasets/spoken_mnist.py b/bindsnet/datasets/spoken_mnist.py index 7a86fee6..b8c6f544 100644 --- a/bindsnet/datasets/spoken_mnist.py +++ b/bindsnet/datasets/spoken_mnist.py @@ -250,7 +250,7 @@ def process_data( # Fast Fourier Transform and Power Spectrum NFFT = 512 mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT - pow_frames = (1.0 / NFFT) * (mag_frames ** 2) # Power Spectrum + pow_frames = (1.0 / NFFT) * (mag_frames**2) # Power Spectrum # Log filter banks nfilt = 40 diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 5096c7c9..9e3b78e5 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -195,7 +195,7 @@ def __init__( raise NotImplementedError( "This learning rule is not supported for this Connection type." ) - + def _local_connection1d_update(self, **kwargs) -> None: # language=rst """ @@ -210,39 +210,54 @@ def _local_connection1d_update(self, **kwargs) -> None: out_channels = self.connection.n_filters height_out = self.connection.conv_size + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) - target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - - - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) - target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -263,37 +278,58 @@ def _local_connection2d_update(self, **kwargs) -> None: height_out = self.connection.conv_size[0] width_out = self.connection.conv_size[1] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -316,39 +352,60 @@ def _local_connection3d_update(self, **kwargs) -> None: width_out = self.connection.conv_size[1] depth_out = self.connection.conv_size[2] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_x = target_x * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_s = ( + self.source.s.type(torch.float) + .unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_s = target_s * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_x = ( + self.source.x.unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w -= self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -623,42 +680,65 @@ def _local_connection1d_update(self, **kwargs) -> None: out_channels = self.connection.n_filters height_out = self.connection.conv_size + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) - target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - - - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) - target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) update = 0 # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) - update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) + update -= ( + self.nu[0] + * pre.view(self.connection.w.size()) + * (self.connection.w - self.wmin) + ) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) - update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) + update += ( + self.nu[1] + * post.view(self.connection.w.size()) + * (self.wmax - self.connection.w) + ) self.connection.w += update @@ -680,40 +760,69 @@ def _local_connection2d_update(self, **kwargs) -> None: height_out = self.connection.conv_size[0] width_out = self.connection.conv_size[1] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) update = 0 # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) - update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) + update -= ( + self.nu[0] + * pre.view(self.connection.w.size()) + * (self.connection.w - self.wmin) + ) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) - update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) + update += ( + self.nu[1] + * post.view(self.connection.w.size()) + * (self.wmax - self.connection.w) + ) self.connection.w += update @@ -737,42 +846,71 @@ def _local_connection3d_update(self, **kwargs) -> None: width_out = self.connection.conv_size[1] depth_out = self.connection.conv_size[2] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_x = target_x * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_s = ( + self.source.s.type(torch.float) + .unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_s = target_s * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_x = ( + self.source.x.unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) update = 0 # Pre-synaptic update. if self.nu[0]: - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) - update -= self.nu[0] * pre.view(self.connection.w.size()) * (self.connection.w - self.wmin) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) + update -= ( + self.nu[0] + * pre.view(self.connection.w.size()) + * (self.connection.w - self.wmin) + ) # Post-synaptic update. if self.nu[1]: - post = self.reduction(torch.bmm(target_s, source_x),dim=0) - update += self.nu[1] * post.view(self.connection.w.size()) * (self.wmax - self.connection.w) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) + update += ( + self.nu[1] + * post.view(self.connection.w.size()) + * (self.wmax - self.connection.w) + ) self.connection.w += update @@ -829,8 +967,6 @@ def _conv1d_connection_update(self, **kwargs) -> None: super().update() - - def _conv2d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -1056,38 +1192,53 @@ def _local_connection1d_update(self, **kwargs) -> None: out_channels = self.connection.n_filters height_out = self.connection.conv_size + target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) + target_x = target_x * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_x = self.target.x.reshape(batch_size, out_channels * height_out, 1) - target_x = target_x * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - - - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) - target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -1108,36 +1259,57 @@ def _local_connection2d_update(self, **kwargs) -> None: height_out = self.connection.conv_size[0] width_out = self.connection.conv_size[1] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_x = target_x * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_s = ( + self.source.s.type(torch.float) + .unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + source_x = ( + self.source.x.unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -1160,38 +1332,59 @@ def _local_connection3d_update(self, **kwargs) -> None: width_out = self.connection.conv_size[1] depth_out = self.connection.conv_size[2] - target_x = self.target.x.reshape(batch_size, out_channels * height_out * width_out * depth_out, 1) - target_x = target_x * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_x = self.target.x.reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_x = target_x * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_s = ( + self.source.s.type(torch.float) + .unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - source_x = self.source.x.unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_s = target_s * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) + source_x = ( + self.source.x.unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) # Pre-synaptic update. - pre = self.reduction(torch.bmm(target_x,source_s), dim=0) + pre = self.reduction(torch.bmm(target_x, source_s), dim=0) self.connection.w += self.nu[0] * pre.view(self.connection.w.size()) # Post-synaptic update. - post = self.reduction(torch.bmm(target_s, source_x),dim=0) + post = self.reduction(torch.bmm(target_s, source_x), dim=0) self.connection.w += self.nu[1] * post.view(self.connection.w.size()) super().update() @@ -1486,39 +1679,56 @@ def _local_connection1d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-1, kernel_height,stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + self.p_plus = ( + self.p_plus.unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + if not hasattr(self, "p_minus"): self.p_minus = torch.zeros( batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out).to(self.connection.w.device) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out, 1 + ) + self.p_minus = self.p_minus * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-1, kernel_height, stride).reshape( - batch_size, - height_out, - in_channels*kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + source_s = ( + self.source.s.type(torch.float) + .unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out,1) - target_s = target_s * torch.eye(out_channels*height_out).to(self.connection.w.device) - # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -1526,15 +1736,14 @@ def _local_connection1d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() - def _local_connection2d_update(self, **kwargs) -> None: # language=rst """ @@ -1551,7 +1760,6 @@ def _local_connection2d_update(self, **kwargs) -> None: height_out = self.connection.conv_size[0] width_out = self.connection.conv_size[1] - # Initialize eligibility. if not hasattr(self, "eligibility"): self.eligibility = torch.zeros( @@ -1577,39 +1785,58 @@ def _local_connection2d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, + self.p_plus = ( + self.p_plus.unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + self.p_minus = self.p_minus * torch.eye( + out_channels * height_out * width_out + ).to(self.connection.w.device) + + # Reshaping spike occurrences. + source_s = ( + self.source.s.type(torch.float) + .unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, height_out * width_out, in_channels * kernel_height * kernel_width, - ).repeat( + ) + .repeat( 1, out_channels, 1, - ).to(self.connection.w.device) - - if not hasattr(self, "p_minus"): - self.p_minus = torch.zeros( - batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out * width_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + .to(self.connection.w.device) + ) - # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-2, kernel_height, stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels*kernel_height*kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out, 1) - target_s = target_s * torch.eye(out_channels*height_out*width_out).to(self.connection.w.device) - # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -1617,15 +1844,14 @@ def _local_connection2d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() - def _local_connection3d_update(self, **kwargs) -> None: # language=rst """ @@ -1645,7 +1871,6 @@ def _local_connection3d_update(self, **kwargs) -> None: width_out = self.connection.conv_size[1] depth_out = self.connection.conv_size[2] - # Initialize eligibility. if not hasattr(self, "eligibility"): self.eligibility = torch.zeros( @@ -1670,41 +1895,60 @@ def _local_connection3d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth,stride[2]).reshape( - batch_size, - height_out * width_out * depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + self.p_plus = ( + self.p_plus.unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + if not hasattr(self, "p_minus"): self.p_minus = torch.zeros( batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out * width_out * depth_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + self.p_minus = self.p_minus * torch.eye( + out_channels * height_out * width_out * depth_out + ).to(self.connection.w.device) # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-3, kernel_height, stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels*kernel_height*kernel_width*kernel_depth, - ).repeat( - 1, - out_channels, - 1, + source_s = ( + self.source.s.type(torch.float) + .unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_s = target_s * torch.eye( + out_channels * height_out * width_out * depth_out ).to(self.connection.w.device) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels*height_out*width_out*depth_out,1) - target_s = target_s * torch.eye(out_channels*height_out*width_out*depth_out).to(self.connection.w.device) - # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -1712,15 +1956,14 @@ def _local_connection3d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() - def _conv1d_connection_update(self, **kwargs) -> None: # language=rst """ @@ -2153,40 +2396,57 @@ def _local_connection1d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-1, kernel_height, stride).reshape( - batch_size, + self.p_plus = ( + self.p_plus.unfold(-1, kernel_height, stride) + .reshape( + batch_size, + height_out, + in_channels * kernel_height, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + + if not hasattr(self, "p_minus"): + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.connection.w.device + ) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out, 1 + ) + self.p_minus = self.p_minus * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + + # Reshaping spike occurrences. + source_s = ( + self.source.s.type(torch.float) + .unfold(-1, kernel_height, stride) + .reshape( + batch_size, height_out, in_channels * kernel_height, - ).repeat( + ) + .repeat( 1, out_channels, 1, - ).to(self.connection.w.device) - - if not hasattr(self, "p_minus"): - self.p_minus = torch.zeros( - batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out).to(self.connection.w.device) + .to(self.connection.w.device) + ) - # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-1, kernel_height,stride).reshape( - batch_size, - height_out, - in_channels * kernel_height, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - # print(target_x.shape, source_s.shape) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out,1) - target_s = target_s * torch.eye(out_channels * height_out).to(self.connection.w.device) - + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out).to( + self.connection.w.device + ) + # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -2194,9 +2454,9 @@ def _local_connection1d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() @@ -2218,8 +2478,6 @@ def _local_connection2d_update(self, **kwargs) -> None: height_out = self.connection.conv_size[0] width_out = self.connection.conv_size[1] - - # Initialize eligibility. if not hasattr(self, "eligibility"): self.eligibility = torch.zeros( @@ -2253,40 +2511,59 @@ def _local_connection2d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-2, kernel_height, stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + self.p_plus = ( + self.p_plus.unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + if not hasattr(self, "p_minus"): self.p_minus = torch.zeros( batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out * width_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + self.p_minus = self.p_minus * torch.eye( + out_channels * height_out * width_out + ).to(self.connection.w.device) # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-2, kernel_height,stride[0]).unfold(-2, kernel_width, stride[1]).reshape( - batch_size, - height_out*width_out, - in_channels * kernel_height * kernel_width, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + source_s = ( + self.source.s.type(torch.float) + .unfold(-2, kernel_height, stride[0]) + .unfold(-2, kernel_width, stride[1]) + .reshape( + batch_size, + height_out * width_out, + in_channels * kernel_height * kernel_width, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + # print(target_x.shape, source_s.shape) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out * width_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out).to(self.connection.w.device) - + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out, 1 + ) + target_s = target_s * torch.eye(out_channels * height_out * width_out).to( + self.connection.w.device + ) + # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -2294,9 +2571,9 @@ def _local_connection2d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() @@ -2353,42 +2630,61 @@ def _local_connection3d_update(self, **kwargs) -> None: self.p_plus = torch.zeros( batch_size, *self.source.shape, device=self.connection.w.device ) - self.p_plus = self.p_plus.unfold(-3, kernel_height, stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width*kernel_depth, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + self.p_plus = ( + self.p_plus.unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + if not hasattr(self, "p_minus"): self.p_minus = torch.zeros( batch_size, *self.target.shape, device=self.connection.w.device ) - self.p_minus = self.p_minus.reshape(batch_size,\ - out_channels * height_out * width_out * depth_out, 1) - self.p_minus = self.p_minus *\ - torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) + self.p_minus = self.p_minus.reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + self.p_minus = self.p_minus * torch.eye( + out_channels * height_out * width_out * depth_out + ).to(self.connection.w.device) # Reshaping spike occurrences. - source_s = self.source.s.type(torch.float).unfold(-3, kernel_height,stride[0]).unfold( - -3, kernel_width, stride[1]).unfold(-3, kernel_depth, stride[2]).reshape( - batch_size, - height_out*width_out*depth_out, - in_channels * kernel_height * kernel_width * kernel_depth, - ).repeat( - 1, - out_channels, - 1, - ).to(self.connection.w.device) - + source_s = ( + self.source.s.type(torch.float) + .unfold(-3, kernel_height, stride[0]) + .unfold(-3, kernel_width, stride[1]) + .unfold(-3, kernel_depth, stride[2]) + .reshape( + batch_size, + height_out * width_out * depth_out, + in_channels * kernel_height * kernel_width * kernel_depth, + ) + .repeat( + 1, + out_channels, + 1, + ) + .to(self.connection.w.device) + ) + # print(target_x.shape, source_s.shape) - target_s = self.target.s.type(torch.float).reshape(batch_size, out_channels * height_out * width_out * depth_out,1) - target_s = target_s * torch.eye(out_channels * height_out * width_out * depth_out).to(self.connection.w.device) - + target_s = self.target.s.type(torch.float).reshape( + batch_size, out_channels * height_out * width_out * depth_out, 1 + ) + target_s = target_s * torch.eye( + out_channels * height_out * width_out * depth_out + ).to(self.connection.w.device) + # Update P^+ and P^- values. self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) self.p_plus += a_plus * source_s @@ -2396,9 +2692,9 @@ def _local_connection3d_update(self, **kwargs) -> None: self.p_minus += a_minus * target_s # Calculate point eligibility value. - self.eligibility = torch.bmm( - target_s, self.p_plus - ) + torch.bmm(self.p_minus, source_s) + self.eligibility = torch.bmm(target_s, self.p_plus) + torch.bmm( + self.p_minus, source_s + ) self.eligibility = self.eligibility.view(batch_size, *self.connection.w.shape) super().update() diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index c99b42cf..cf8b709c 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -1212,8 +1212,8 @@ def __init__( self.r = torch.rand(n) self.a = 0.02 * torch.ones(n) self.b = 0.2 * torch.ones(n) - self.c = -65.0 + 15 * (self.r ** 2) - self.d = 8 - 6 * (self.r ** 2) + self.c = -65.0 + 15 * (self.r**2) + self.d = 8 - 6 * (self.r**2) self.S = 0.5 * torch.rand(n, n) self.excitatory = torch.ones(n).byte() @@ -1282,8 +1282,8 @@ def forward(self, x: torch.Tensor) -> None: ) # Apply v and u updates. - self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) - self.v += self.dt * 0.5 * (0.04 * self.v ** 2 + 5 * self.v + 140 - self.u + x) + self.v += self.dt * 0.5 * (0.04 * self.v**2 + 5 * self.v + 140 - self.u + x) + self.v += self.dt * 0.5 * (0.04 * self.v**2 + 5 * self.v + 140 - self.u + x) self.u += self.dt * self.a * (self.b * self.v - self.u) # Voltage clipping to lower bound. @@ -1518,7 +1518,7 @@ def set_batch_size(self, batch_size) -> None: def AlphaKernel(self, dt): t = torch.arange(0, self.res_window_size, dt) - kernelVec = (1 / (self.tau ** 2)) * t * torch.exp(-t / self.tau) + kernelVec = (1 / (self.tau**2)) * t * torch.exp(-t / self.tau) return torch.flip(kernelVec, [0]) def AlphaKernelSLAYER(self, dt): diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 33f6c79b..34c89527 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -253,6 +253,7 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() + class Conv1dConnection(AbstractConnection): # language=rst """ @@ -398,6 +399,7 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() + class Conv2dConnection(AbstractConnection): # language=rst """ @@ -813,7 +815,9 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) + self.firing_rates = torch.zeros( + self.source.batch_size, *(self.source.s.shape[1:]) + ) class MaxPool2dConnection(AbstractConnection): @@ -901,7 +905,9 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) + self.firing_rates = torch.zeros( + self.source.batch_size, *(self.source.s.shape[1:]) + ) class MaxPoo3dConnection(AbstractConnection): @@ -989,7 +995,9 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() - self.firing_rates = torch.zeros(self.source.batch_size,*(self.source.s.shape[1:])) + self.firing_rates = torch.zeros( + self.source.batch_size, *(self.source.s.shape[1:]) + ) class LocalConnection(AbstractConnection): @@ -1012,7 +1020,7 @@ def __init__( ) -> None: # language=rst """ - Instantiates a ``LocalConnection2D`` object. Source population should have + Instantiates a ``LocalConnection2D`` object. Source population should have square size Neurons in the post-synaptic population are ordered by receptive field; that is, @@ -1169,9 +1177,10 @@ def reset_state_variables(self) -> None: """ super().reset_state_variables() + class LocalConnection1D(AbstractConnection): """ - Specifies a one-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H); + Specifies a one-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H); The logic is different from the original LocalConnection implementation (where masks were used with normal dense connections). """ @@ -1185,7 +1194,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: """ Instantiates a 'LocalConnection1D` object. Source population can be multi-channel. @@ -1212,7 +1221,6 @@ def __init__( super().__init__(source, target, nu, reduction, weight_decay, **kwargs) - self.kernel_size = kernel_size self.stride = stride self.n_filters = n_filters @@ -1222,16 +1230,13 @@ def __init__( source.shape[1], ) - height = int(( - input_height - self.kernel_size - ) / self.stride) + 1 - + height = int((input_height - self.kernel_size) / self.stride) + 1 self.conv_size = height w = kwargs.get("w", None) - error = ( + error = ( "Target dimensionality must be (in_channels," "n_filters*conv_size," "kernel_size)" @@ -1239,16 +1244,14 @@ def __init__( if w is None: w = torch.rand( - self.in_channels, - self.n_filters * self.conv_size, - self.kernel_size + self.in_channels, self.n_filters * self.conv_size, self.kernel_size ) else: assert w.shape == ( - self.in_channels, + self.in_channels, self.out_channels * self.conv_size, - self.kernel_size - ), error + self.kernel_size, + ), error if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1256,7 +1259,6 @@ def __init__( self.w = Parameter(w, requires_grad=False) self.b = Parameter(kwargs.get("b", None), requires_grad=False) - def compute(self, s: torch.Tensor) -> torch.Tensor: """ Compute pre-activations given spikes using layer weights. @@ -1271,20 +1273,22 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: batch_size = s.shape[0] - self.s_unfold = s.unfold( - -1,self.kernel_size,self.stride - ).reshape( - s.shape[0], - self.in_channels, - self.conv_size, - self.kernel_size, - ).repeat( - 1, - 1, - self.n_filters, - 1, + self.s_unfold = ( + s.unfold(-1, self.kernel_size, self.stride) + .reshape( + s.shape[0], + self.in_channels, + self.conv_size, + self.kernel_size, + ) + .repeat( + 1, + 1, + self.n_filters, + 1, + ) ) - + a_post = self.s_unfold.to(self.w.device) * self.w return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) @@ -1303,13 +1307,10 @@ def normalize(self) -> None: if self.norm is not None: # get a view and modify in-place # w: ch_in, ch_out * h_out, k - w = self.w.view( - self.w.shape[0]*self.w.shape[1], self.w.shape[2] - ) + w = self.w.view(self.w.shape[0] * self.w.shape[1], self.w.shape[2]) for fltr in range(w.shape[0]): - w[fltr,:] *= self.norm / w[fltr,:].sum(0) - + w[fltr, :] *= self.norm / w[fltr, :].sum(0) def reset_state_variables(self) -> None: """ @@ -1319,9 +1320,10 @@ def reset_state_variables(self) -> None: self.target.reset_state_variables() + class LocalConnection2D(AbstractConnection): """ - Specifies a two-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H, W); + Specifies a two-dimensional local connection between one or two population of neurons supporting multi-channel inputs with shape (C, H, W); The logic is different from the original LocalConnection implementation (where masks were used with normal dense connections) """ @@ -1335,7 +1337,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: """ Instantiates a 'LocalConnection2D` object. Source population can be multi-channel. @@ -1375,13 +1377,8 @@ def __init__( source.shape[2], ) - height = int(( - input_height - self.kernel_size[0] - ) / self.stride[0]) + 1 - width = int(( - input_width - self.kernel_size[1] - ) / self.stride[1]) + 1 - + height = int((input_height - self.kernel_size[0]) / self.stride[0]) + 1 + width = int((input_width - self.kernel_size[1]) / self.stride[1]) + 1 self.conv_size = (height, width) self.conv_prod = int(np.prod(self.conv_size)) @@ -1389,7 +1386,7 @@ def __init__( w = kwargs.get("w", None) - error = ( + error = ( "Target dimensionality must be (in_channels," "n_filters*conv_prod," "kernel_prod)" @@ -1397,16 +1394,14 @@ def __init__( if w is None: w = torch.rand( - self.in_channels, - self.n_filters * self.conv_prod, - self.kernel_prod + self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod ) else: assert w.shape == ( - self.in_channels, + self.in_channels, self.out_channels * self.conv_prod, - self.kernel_prod - ), error + self.kernel_prod, + ), error if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1414,7 +1409,6 @@ def __init__( self.w = Parameter(w, requires_grad=False) self.b = Parameter(kwargs.get("b", None), requires_grad=False) - def compute(self, s: torch.Tensor) -> torch.Tensor: """ Compute pre-activations given spikes using layer weights. @@ -1429,22 +1423,23 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: batch_size = s.shape[0] - self.s_unfold = s.unfold( - -2,self.kernel_size[0],self.stride[0] - ).unfold( - -2,self.kernel_size[1],self.stride[1] - ).reshape( - s.shape[0], - self.in_channels, - self.conv_prod, - self.kernel_prod, - ).repeat( - 1, - 1, - self.n_filters, - 1, + self.s_unfold = ( + s.unfold(-2, self.kernel_size[0], self.stride[0]) + .unfold(-2, self.kernel_size[1], self.stride[1]) + .reshape( + s.shape[0], + self.in_channels, + self.conv_prod, + self.kernel_prod, + ) + .repeat( + 1, + 1, + self.n_filters, + 1, + ) ) - + a_post = self.s_unfold.to(self.w.device) * self.w return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) @@ -1463,13 +1458,10 @@ def normalize(self) -> None: if self.norm is not None: # get a view and modify in-place # w: ch_in, ch_out * w_out * h_out, k1 * k2 - w = self.w.view( - self.w.shape[0]*self.w.shape[1], self.w.shape[2] - ) + w = self.w.view(self.w.shape[0] * self.w.shape[1], self.w.shape[2]) for fltr in range(w.shape[0]): - w[fltr,:] *= self.norm / w[fltr,:].sum(0) - + w[fltr, :] *= self.norm / w[fltr, :].sum(0) def reset_state_variables(self) -> None: """ @@ -1496,7 +1488,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: """ Instantiates a 'LocalConnection3D` object. Source population can be multi-channel. @@ -1534,19 +1526,12 @@ def __init__( source.shape[0], source.shape[1], source.shape[2], - source.shape[3] + source.shape[3], ) - height = int(( - input_height - self.kernel_size[0] - ) / self.stride[0]) + 1 - width = int(( - input_width - self.kernel_size[1] - ) / self.stride[1]) + 1 - depth = int(( - input_depth - self.kernel_size[2] - ) / self.stride[2]) + 1 - + height = int((input_height - self.kernel_size[0]) / self.stride[0]) + 1 + width = int((input_width - self.kernel_size[1]) / self.stride[1]) + 1 + depth = int((input_depth - self.kernel_size[2]) / self.stride[2]) + 1 self.conv_size = (height, width, depth) self.conv_prod = int(np.prod(self.conv_size)) @@ -1554,7 +1539,7 @@ def __init__( w = kwargs.get("w", None) - error = ( + error = ( "Target dimensionality must be (in_channels," "n_filters*conv_prod," "kernel_prod)" @@ -1562,16 +1547,14 @@ def __init__( if w is None: w = torch.rand( - self.in_channels, - self.n_filters * self.conv_prod, - self.kernel_prod + self.in_channels, self.n_filters * self.conv_prod, self.kernel_prod ) else: assert w.shape == ( - self.in_channels, + self.in_channels, self.out_channels * self.conv_prod, - self.kernel_prod - ), error + self.kernel_prod, + ), error if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) @@ -1579,7 +1562,6 @@ def __init__( self.w = Parameter(w, requires_grad=False) self.b = Parameter(kwargs.get("b", None), requires_grad=False) - def compute(self, s: torch.Tensor) -> torch.Tensor: """ Compute pre-activations given spikes using layer weights. @@ -1594,24 +1576,24 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: batch_size = s.shape[0] - self.s_unfold = s.unfold( - -3,self.kernel_size[0],self.stride[0] - ).unfold( - -3,self.kernel_size[1],self.stride[1] - ).unfold( - -3,self.kernel_size[2],self.stride[2] - ).reshape( - s.shape[0], - self.in_channels, - self.conv_prod, - self.kernel_prod, - ).repeat( - 1, - 1, - self.n_filters, - 1, + self.s_unfold = ( + s.unfold(-3, self.kernel_size[0], self.stride[0]) + .unfold(-3, self.kernel_size[1], self.stride[1]) + .unfold(-3, self.kernel_size[2], self.stride[2]) + .reshape( + s.shape[0], + self.in_channels, + self.conv_prod, + self.kernel_prod, + ) + .repeat( + 1, + 1, + self.n_filters, + 1, + ) ) - + a_post = self.s_unfold.to(self.w.device) * self.w return a_post.sum(-1).sum(1).view(batch_size, *self.target.shape) @@ -1630,13 +1612,10 @@ def normalize(self) -> None: if self.norm is not None: # get a view and modify in-place # w: ch_in, ch_out * w_out * h_out * d_out, k1*k2*k3 - w = self.w.view( - self.w.shape[0]*self.w.shape[1], self.w.shape[2] - ) + w = self.w.view(self.w.shape[0] * self.w.shape[1], self.w.shape[2]) for fltr in range(w.shape[0]): - w[fltr,:] *= self.norm / w[fltr,:].sum(0) - + w[fltr, :] *= self.norm / w[fltr, :].sum(0) def reset_state_variables(self) -> None: """ @@ -1645,7 +1624,8 @@ def reset_state_variables(self) -> None: super().reset_state_variables() self.target.reset_state_variables() - + + class MeanFieldConnection(AbstractConnection): # language=rst """ diff --git a/bindsnet/utils.py b/bindsnet/utils.py index 628bf17d..d1b00257 100644 --- a/bindsnet/utils.py +++ b/bindsnet/utils.py @@ -246,8 +246,7 @@ def reshape_local_connection_2d_weights( for n2 in range(c2): for feature in range(n_filters): n = n1 * c2 + n2 - filter_ = w[feature, n1, n2, :, : - ].view(k1, k2) + filter_ = w[feature, n1, n2, :, :].view(k1, k2) w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_ if c1 == 1 and c2 == 1: @@ -276,4 +275,4 @@ def reshape_local_connection_2d_weights( (n1 * c2 + n2) * k2 : (n1 * c2 + n2 + 1) * k2, ] - return square \ No newline at end of file + return square diff --git a/examples/mnist/conv3d_MNIST.py b/examples/mnist/conv3d_MNIST.py index be0b280e..8a8e4c50 100644 --- a/examples/mnist/conv3d_MNIST.py +++ b/examples/mnist/conv3d_MNIST.py @@ -94,7 +94,7 @@ kernel_size=kernel_size, stride=stride, update_rule=PostPre, - norm=0.4 * kernel_size ** 3, + norm=0.4 * kernel_size**3, nu=[1e-4, 1e-2], wmax=1.0, ) diff --git a/examples/mnist/conv_mnist.py b/examples/mnist/conv_mnist.py index a23f0b4b..f91fd5c6 100644 --- a/examples/mnist/conv_mnist.py +++ b/examples/mnist/conv_mnist.py @@ -98,7 +98,7 @@ kernel_size=kernel_size, stride=stride, update_rule=PostPre, - norm=0.4 * kernel_size ** 2, + norm=0.4 * kernel_size**2, nu=[1e-4, 1e-2], wmax=1.0, ) diff --git a/examples/mnist/loc1d_mnist.py b/examples/mnist/loc1d_mnist.py index c96cb0cc..403a69ca 100644 --- a/examples/mnist/loc1d_mnist.py +++ b/examples/mnist/loc1d_mnist.py @@ -29,11 +29,11 @@ in_channels = 1 n_filters = 25 input_shape = 784 -kernel_size = 28*2 +kernel_size = 28 * 2 stride = 28 tc_theta_decay = 1e6 theta_plus = 0.05 -norm = 0.2*kernel_size +norm = 0.2 * kernel_size wmin = 0.0 wmax = 1.0 nu = (1e-4, 1e-2) @@ -49,11 +49,7 @@ # Build network network = Network() -input_layer = Input( - shape=[in_channels, input_shape], - traces=True, - tc_trace=20 -) +input_layer = Input(shape=[in_channels, input_shape], traces=True, tc_trace=20) compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 conv_size = compute_conv_size(input_shape, kernel_size, stride) @@ -76,7 +72,7 @@ output_layer, kernel_size=kernel_size, stride=stride, - n_filters = n_filters, + n_filters=n_filters, nu=nu, update_rule=PostPre, wmin=wmin, @@ -87,8 +83,8 @@ w_inh_LC = torch.zeros(n_filters, conv_size, n_filters, conv_size) for c in range(n_filters): for w1 in range(conv_size): - w_inh_LC[c, w1, :, w1] = -inh - w_inh_LC[c, w1, c, w1] = 0 + w_inh_LC[c, w1, :, w1] = -inh + w_inh_LC[c, w1, c, w1] = 0 w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) @@ -148,7 +144,11 @@ start = t() train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=0, + pin_memory=gpu, ) for step, batch in enumerate(tqdm(train_dataloader)): @@ -162,7 +162,6 @@ # Run the network on the input. network.run(inputs=inputs, time=time, input_time_dim=1) - network.reset_state_variables() # Reset state variables. diff --git a/examples/mnist/loc2d_mnist.py b/examples/mnist/loc2d_mnist.py index 7acd29f8..157cd401 100644 --- a/examples/mnist/loc2d_mnist.py +++ b/examples/mnist/loc2d_mnist.py @@ -1,4 +1,3 @@ - import torch from torch.nn.modules.utils import _pair @@ -34,10 +33,10 @@ stride = _pair(4) tc_theta_decay = 1e6 theta_plus = 0.05 -norm = 0.2*kernel_size[0]*kernel_size[1] +norm = 0.2 * kernel_size[0] * kernel_size[1] wmin = 0.0 wmax = 1.0 -nu = (0.0001,0.01) +nu = (0.0001, 0.01) inh = 25.0 dt = 1.0 time = 250 @@ -53,9 +52,7 @@ network = Network() input_layer = Input( - shape=[in_channels, input_shape[0], input_shape[1]], - traces=True, - tc_trace=20 + shape=[in_channels, input_shape[0], input_shape[1]], traces=True, tc_trace=20 ) compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 @@ -78,7 +75,7 @@ output_layer, kernel_size=kernel_size, stride=stride, - n_filters = n_filters, + n_filters=n_filters, nu=nu, update_rule=PostPre, wmin=wmin, @@ -86,7 +83,9 @@ norm=norm, ) -w_inh_LC = torch.zeros(n_filters, conv_size[0], conv_size[1], n_filters, conv_size[0], conv_size[1]) +w_inh_LC = torch.zeros( + n_filters, conv_size[0], conv_size[1], n_filters, conv_size[0], conv_size[1] +) for c in range(n_filters): for w1 in range(conv_size[0]): for w2 in range(conv_size[0]): @@ -126,7 +125,11 @@ download=True, train=True, transform=transforms.Compose( - [transforms.ToTensor(), transforms.CenterCrop((input_shape[0], input_shape[1])), transforms.Lambda(lambda x: x * intensity)] + [ + transforms.ToTensor(), + transforms.CenterCrop((input_shape[0], input_shape[1])), + transforms.Lambda(lambda x: x * intensity), + ] ), ) @@ -152,14 +155,22 @@ start = t() train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=0, + pin_memory=gpu, ) for step, batch in enumerate(tqdm(train_dataloader)): # Get next input sample. if step > n_train: break - inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, input_shape[0], input_shape[1])} + inputs = { + "X": batch["encoded_image"].view( + time, batch_size, 1, input_shape[0], input_shape[1] + ) + } if gpu: inputs = {k: v.cuda() for k, v in inputs.items()} label = batch["label"] @@ -169,7 +180,9 @@ # Optionally plot various simulation information. if plot: - weights1_im = plot_local_connection_2d_weights(network.connections[("X", "Y")], im=weights1_im) + weights1_im = plot_local_connection_2d_weights( + network.connections[("X", "Y")], im=weights1_im + ) plt.pause(1) network.reset_state_variables() # Reset state variables. @@ -178,5 +191,5 @@ print("Training complete.\n") weights1_im = plot_local_connection_2d_weights(network.connections[("X", "Y")]) -plt.savefig('test.png') -plt.pause(100) \ No newline at end of file +plt.savefig("test.png") +plt.pause(100) diff --git a/examples/mnist/loc3d_mnist.py b/examples/mnist/loc3d_mnist.py index 4cf6c18b..49e11f46 100644 --- a/examples/mnist/loc3d_mnist.py +++ b/examples/mnist/loc3d_mnist.py @@ -34,10 +34,10 @@ stride = _triple(2) tc_theta_decay = 1e6 theta_plus = 0.05 -norm = 0.2*kernel_size[0]*kernel_size[1]*kernel_size[2] +norm = 0.2 * kernel_size[0] * kernel_size[1] * kernel_size[2] wmin = 0.0 wmax = 1.0 -nu = (0.0001,0.01) +nu = (0.0001, 0.01) inh = 25.0 dt = 1.0 time = 250 @@ -50,7 +50,11 @@ # Build network network = Network() -input_layer = Input(n=input_shape[0]*input_shape[1]*input_shape[2], shape=(in_channels, input_shape[0], input_shape[1], input_shape[2]), traces=True) +input_layer = Input( + n=input_shape[0] * input_shape[1] * input_shape[2], + shape=(in_channels, input_shape[0], input_shape[1], input_shape[2]), + traces=True, +) compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 conv_size = _triple(compute_conv_size(input_shape[0], kernel_size[0], stride[0])) @@ -72,7 +76,7 @@ output_layer, kernel_size=kernel_size, stride=stride, - n_filters = n_filters, + n_filters=n_filters, nu=nu, update_rule=PostPre, wmin=wmin, @@ -80,7 +84,16 @@ norm=norm, ) -w_inh_LC = torch.zeros(n_filters, conv_size[0], conv_size[1], conv_size[2], n_filters, conv_size[0], conv_size[1], conv_size[2]) +w_inh_LC = torch.zeros( + n_filters, + conv_size[0], + conv_size[1], + conv_size[2], + n_filters, + conv_size[0], + conv_size[1], + conv_size[2], +) for c in range(n_filters): for w1 in range(conv_size[0]): @@ -122,7 +135,11 @@ download=True, train=True, transform=transforms.Compose( - [transforms.ToTensor(), transforms.CenterCrop((input_shape[0], input_shape[1])), transforms.Lambda(lambda x: x * intensity)] + [ + transforms.ToTensor(), + transforms.CenterCrop((input_shape[0], input_shape[1])), + transforms.Lambda(lambda x: x * intensity), + ] ), ) @@ -146,7 +163,11 @@ start = t() train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=gpu + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=0, + pin_memory=gpu, ) for step, batch in enumerate(tqdm(train_dataloader)): @@ -166,7 +187,7 @@ # Run the network on the input. network.run(inputs=inputs, time=time, input_time_dim=1) - + print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) -print("Training complete.\n") \ No newline at end of file +print("Training complete.\n") diff --git a/examples/tensorboard/tensorboard.py b/examples/tensorboard/tensorboard.py index 127b1e52..a35a7e8d 100644 --- a/examples/tensorboard/tensorboard.py +++ b/examples/tensorboard/tensorboard.py @@ -87,7 +87,7 @@ kernel_size=kernel_size, stride=stride, update_rule=PostPre, - norm=0.4 * kernel_size ** 2, + norm=0.4 * kernel_size**2, nu=[1e-4, 1e-2], wmax=1.0, )