Skip to content

Commit

Permalink
Merge pull request #536 from hafezgh/localconnections
Browse files Browse the repository at this point in the history
Localconnections
  • Loading branch information
Hananel-Hazan authored Feb 13, 2022
2 parents b850dd1 + 9049bc6 commit 12d060c
Show file tree
Hide file tree
Showing 7 changed files with 2,568 additions and 26 deletions.
84 changes: 83 additions & 1 deletion bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.nn.modules.utils import _pair

from bindsnet.utils import reshape_conv2d_weights, reshape_locally_connected_weights
from bindsnet.utils import (
reshape_conv2d_weights,
reshape_locally_connected_weights,
reshape_local_connection_2d_weights,
)

plt.ion()

Expand Down Expand Up @@ -378,6 +382,84 @@ def plot_locally_connected_weights(
return im


def plot_local_connection_2d_weights(
lc: object,
input_channel: int = 0,
output_channel: int = None,
im: Optional[AxesImage] = None,
lines: bool = True,
figsize: Tuple[int, int] = (5, 5),
cmap: str = "hot_r",
color: str = "r",
) -> AxesImage:
# language=rst
"""
Plot a connection weight matrix of a :code:`Connection` with `locally connected
structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.
:param lc: An object of the class LocalConnection2D
:param input_channel: The input channel to plot its corresponding weights, default is the first channel
:param output_channel: If not None, will only plot the weights corresponding to this output channel (filter)
:param lines: Indicates whether or not draw horizontal and vertical lines separating input regions.
:param figsize: Horizontal and vertical figure size in inches.
:param cmap: Matplotlib colormap.
:return: ``ims, axes``: Used for re-drawing the plots.
"""

n_sqrt = int(np.ceil(np.sqrt(lc.n_filters)))
sel_slice = lc.w.view(
lc.in_channels,
lc.n_filters,
lc.conv_size[0],
lc.conv_size[1],
lc.kernel_size[0],
lc.kernel_size[1],
).cpu()
input_size = _pair(int(np.sqrt(lc.source.n)))
if output_channel is None:
sel_slice = sel_slice[input_channel, ...]
reshaped = reshape_local_connection_2d_weights(
sel_slice, lc.n_filters, lc.kernel_size, lc.conv_size, input_size
)
else:
sel_slice = sel_slice[input_channel, output_channel, ...]
sel_slice = sel_slice.unsqueeze(0)
reshaped = reshape_local_connection_2d_weights(
sel_slice, 1, lc.kernel_size, lc.conv_size, input_size
)
if im == None:
fig, ax = plt.subplots(figsize=figsize)

im = ax.imshow(reshaped.cpu(), cmap=cmap, vmin=lc.wmin, vmax=lc.wmax)
div = make_axes_locatable(ax)
cax = div.append_axes("right", size="5%", pad=0.05)

if lines and output_channel is None:
for i in range(
n_sqrt * lc.kernel_size[0],
n_sqrt * lc.conv_size[0] * lc.kernel_size[0],
n_sqrt * lc.kernel_size[0],
):
ax.axhline(i - 0.5, color=color, linestyle="--")

for i in range(
n_sqrt * lc.kernel_size[1],
n_sqrt * lc.conv_size[1] * lc.kernel_size[1],
n_sqrt * lc.kernel_size[1],
):
ax.axvline(i - 0.5, color=color, linestyle="--")

ax.set_xticks(())
ax.set_yticks(())
ax.set_aspect("auto")

plt.colorbar(im, cax=cax)
fig.tight_layout()

else:
im.set_data(reshaped.cpu())
return im


def plot_assignments(
assignments: torch.Tensor,
im: Optional[AxesImage] = None,
Expand Down
Loading

0 comments on commit 12d060c

Please sign in to comment.