Skip to content

Commit

Permalink
New local connection classes (1D, 2D, and 3D)
Browse files Browse the repository at this point in the history
Apply fix for #537
  • Loading branch information
Hananel-Hazan committed Feb 13, 2022
1 parent 6a3b80e commit 9049bc6
Show file tree
Hide file tree
Showing 12 changed files with 915 additions and 592 deletions.
39 changes: 27 additions & 12 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.nn.modules.utils import _pair

from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights, reshape_local_connection_2d_weights
from bindsnet.utils import (
reshape_conv2d_weights,
reshape_locally_connected_weights,
reshape_local_connection_2d_weights,
)

plt.ion()

Expand Down Expand Up @@ -377,15 +381,17 @@ def plot_locally_connected_weights(

return im

def plot_local_connection_2d_weights(lc : object,

def plot_local_connection_2d_weights(
lc: object,
input_channel: int = 0,
output_channel: int = None,
im: Optional[AxesImage] = None,
lines: bool = True,
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
color: str='r',
) -> AxesImage:
color: str = "r",
) -> AxesImage:
# language=rst
"""
Plot a connection weight matrix of a :code:`Connection` with `locally connected
Expand All @@ -400,23 +406,34 @@ def plot_local_connection_2d_weights(lc : object,
"""

n_sqrt = int(np.ceil(np.sqrt(lc.n_filters)))
sel_slice = lc.w.view(lc.in_channels, lc.n_filters, lc.conv_size[0], lc.conv_size[1], lc.kernel_size[0], lc.kernel_size[1]).cpu()
sel_slice = lc.w.view(
lc.in_channels,
lc.n_filters,
lc.conv_size[0],
lc.conv_size[1],
lc.kernel_size[0],
lc.kernel_size[1],
).cpu()
input_size = _pair(int(np.sqrt(lc.source.n)))
if output_channel is None:
sel_slice = sel_slice[input_channel, ...]
reshaped = reshape_local_connection_2d_weights(sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size)
reshaped = reshape_local_connection_2d_weights(
sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size
)
else:
sel_slice = sel_slice[input_channel, output_channel, ...]
sel_slice = sel_slice.unsqueeze(0)
reshaped = reshape_local_connection_2d_weights(sel_slice, 1, lc.kernel_size, lc.conv_size, input_size)
reshaped = reshape_local_connection_2d_weights(
sel_slice, 1, lc.kernel_size, lc.conv_size, input_size
)
if im == None:
fig, ax = plt.subplots(figsize=figsize)

im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
div = make_axes_locatable(ax)
cax = div.append_axes("right", size="5%", pad=0.05)

if lines and output_channel is None:
if lines and output_channel is None:
for i in range(
n_sqrt * lc.kernel_size[0],
n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
Expand All @@ -430,7 +447,7 @@ def plot_local_connection_2d_weights(lc : object,
n_sqrt * lc.kernel_size[1],
):
ax.axvline(i - 0.5, color=color, linestyle="--")

ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")
Expand All @@ -441,7 +458,7 @@ def plot_local_connection_2d_weights(lc : object,
else:
im.set_data(reshaped.cpu())
return im


def plot_assignments(
assignments: torch.Tensor,
Expand Down Expand Up @@ -825,5 +842,3 @@ def plot_voltages(
plt.tight_layout()

return ims, axes


2 changes: 1 addition & 1 deletion bindsnet/datasets/spoken_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def process_data(
# Fast Fourier Transform and Power Spectrum
NFFT = 512
mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT
pow_frames = (1.0 / NFFT) * (mag_frames ** 2) # Power Spectrum
pow_frames = (1.0 / NFFT) * (mag_frames**2) # Power Spectrum

# Log filter banks
nfilt = 40
Expand Down
Loading

0 comments on commit 9049bc6

Please sign in to comment.