From efed03e643fa3269dec33ee492d891782fa9fa61 Mon Sep 17 00:00:00 2001 From: Sahil Chhoker Date: Tue, 28 Jan 2025 14:02:44 +0530 Subject: [PATCH] Fix: Property layer visualization for HexGrid (#2646) --- mesa/visualization/mpl_space_drawing.py | 142 +++++++++++++++--------- 1 file changed, 89 insertions(+), 53 deletions(-) diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index cc8d59cdd68..3ddd9098ecf 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -9,7 +9,8 @@ import contextlib import itertools import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterator +from functools import lru_cache from itertools import pairwise from typing import Any @@ -18,7 +19,7 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes from matplotlib.cm import ScalarMappable -from matplotlib.collections import LineCollection, PatchCollection +from matplotlib.collections import LineCollection, PatchCollection, PolyCollection from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba from matplotlib.patches import Polygon @@ -159,6 +160,37 @@ def draw_space( return ax +@lru_cache(maxsize=1024, typed=True) +def _get_hexmesh( + width: int, height: int, size: float = 1.0 +) -> Iterator[list[tuple[float, float]]]: + """Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon.""" + + # Helper function for getting the vertices of a hexagon given the center and size + def _get_hex_vertices( + center_x: float, center_y: float, size: float = 1.0 + ) -> list[tuple[float, float]]: + """Get vertices for a hexagon centered at (center_x, center_y).""" + vertices = [ + (center_x, center_y + size), # top + (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right + (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right + (center_x, center_y - size), # bottom + (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left + (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left + ] + return vertices + + x_spacing = np.sqrt(3) * size + y_spacing = 1.5 * size + + for row, col in itertools.product(range(height), range(width)): + # Calculate center position with offset for even rows + x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) + y = row * y_spacing + yield _get_hex_vertices(x, y, size) + + def draw_property_layers( space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes ): @@ -205,46 +237,74 @@ def draw_property_layers( vmax = portrayal.get("vmax", np.max(data)) colorbar = portrayal.get("colorbar", True) - # Draw the layer + # Prepare colormap if "color" in portrayal: - data = data.T rgba_color = to_rgba(portrayal["color"]) - normalized_data = (data - vmin) / (vmax - vmin) - rgba_data = np.full((*data.shape, 4), rgba_color) - rgba_data[..., 3] *= normalized_data * alpha - rgba_data = np.clip(rgba_data, 0, 1) cmap = LinearSegmentedColormap.from_list( layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)] ) - im = ax.imshow( - rgba_data, - origin="lower", - ) - if colorbar: - norm = Normalize(vmin=vmin, vmax=vmax) - sm = ScalarMappable(norm=norm, cmap=cmap) - sm.set_array([]) - ax.figure.colorbar(sm, ax=ax, orientation="vertical") - elif "colormap" in portrayal: cmap = portrayal.get("colormap", "viridis") if isinstance(cmap, list): cmap = LinearSegmentedColormap.from_list(layer_name, cmap) - im = ax.imshow( - data.T, - cmap=cmap, - alpha=alpha, - vmin=vmin, - vmax=vmax, - origin="lower", - ) - if colorbar: - plt.colorbar(im, ax=ax, label=layer_name) + elif isinstance(cmap, str): + cmap = plt.get_cmap(cmap) else: raise ValueError( f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'." ) + if isinstance(space, OrthogonalGrid): + if "color" in portrayal: + data = data.T + normalized_data = (data - vmin) / (vmax - vmin) + rgba_data = np.full((*data.shape, 4), rgba_color) + rgba_data[..., 3] *= normalized_data * alpha + rgba_data = np.clip(rgba_data, 0, 1) + ax.imshow(rgba_data, origin="lower") + else: + ax.imshow( + data.T, + cmap=cmap, + alpha=alpha, + vmin=vmin, + vmax=vmax, + origin="lower", + ) + + elif isinstance(space, HexGrid): + width, height = data.shape + + # Generate hexagon mesh + hexagons = _get_hexmesh(width, height) + + # Normalize colors + norm = Normalize(vmin=vmin, vmax=vmax) + colors = data.ravel() # flatten data to 1D array + + if "color" in portrayal: + normalized_colors = np.clip(norm(colors), 0, 1) + rgba_colors = np.full((len(colors), 4), rgba_color) + rgba_colors[:, 3] = normalized_colors * alpha + else: + rgba_colors = cmap(norm(colors)) + + # Draw hexagons + collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) + ax.add_collection(collection) + + else: + raise NotImplementedError( + f"PropertyLayer visualization not implemented for {type(space)}." + ) + + # Add colorbar if requested + if colorbar: + norm = Normalize(vmin=vmin, vmax=vmax) + sm = ScalarMappable(norm=norm, cmap=cmap) + sm.set_array([]) + plt.colorbar(sm, ax=ax, label=layer_name) + def draw_orthogonal_grid( space: OrthogonalGrid, @@ -349,39 +409,15 @@ def draw_hex_grid( def setup_hexmesh(width, height): """Helper function for creating the hexmesh with unique edges.""" edges = set() - size = 1.0 - x_spacing = np.sqrt(3) * size - y_spacing = 1.5 * size - - def get_hex_vertices( - center_x: float, center_y: float - ) -> list[tuple[float, float]]: - """Get vertices for a hexagon centered at (center_x, center_y).""" - vertices = [ - (center_x, center_y + size), # top - (center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right - (center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right - (center_x, center_y - size), # bottom - (center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left - (center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left - ] - return vertices # Generate edges for each hexagon - for row, col in itertools.product(range(height), range(width)): - # Calculate center position for each hexagon with offset for even rows - x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2) - y = row * y_spacing - - vertices = get_hex_vertices(x, y) - + for vertices in _get_hexmesh(width, height): # Edge logic, connecting each vertex to the next for v1, v2 in pairwise([*vertices, vertices[0]]): # Sort vertices to ensure consistent edge representation and avoid duplicates. edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))])) edges.add(edge) - # Return LineCollection for hexmesh return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1) if draw_grid: