Skip to content

Commit

Permalink
added property_layer with altair
Browse files Browse the repository at this point in the history
  • Loading branch information
sanika-n committed Jan 25, 2025
1 parent 13518b2 commit 4492a42
Showing 1 changed file with 133 additions and 11 deletions.
144 changes: 133 additions & 11 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@
with contextlib.suppress(ImportError):
import altair as alt

import numpy as np
import pandas as pd
from matplotlib.colors import to_rgba

import mesa.experimental
from mesa.experimental.cell_space import DiscreteSpace, Grid
from mesa.space import ContinuousSpace, _Grid
from mesa.space import ContinuousSpace, PropertyLayer, _Grid
from mesa.visualization.utils import update_counter


Expand Down Expand Up @@ -46,13 +51,18 @@ def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceAltair(model):
return SpaceAltair(model, agent_portrayal)
return SpaceAltair(model, agent_portrayal, propertylayer_portrayal)

return MakeSpaceAltair


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
def SpaceAltair(
model,
agent_portrayal,
propertylayer_portrayal,
dependencies: list[any] | None = None,
):
"""Create an Altair-based space visualization component.
Returns:
Expand All @@ -64,7 +74,7 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
# Sometimes the space is defined as model.space instead of model.grid
space = model.space

chart = _draw_grid(space, agent_portrayal)
chart = _draw_grid(space, agent_portrayal, propertylayer_portrayal)
solara.FigureAltair(chart)


Expand Down Expand Up @@ -136,7 +146,7 @@ def _get_agent_data_continuous_space(space: ContinuousSpace, agent_portrayal):
return all_agent_data


def _draw_grid(space, agent_portrayal):
def _draw_grid(space, agent_portrayal, propertylayer_portrayal):
match space:
case Grid():
all_agent_data = _get_agent_data_new_discrete_space(space, agent_portrayal)
Expand All @@ -159,30 +169,142 @@ def _draw_grid(space, agent_portrayal):
# no y-axis label
"y": alt.Y("y", axis=None, type=x_y_type),
"tooltip": [
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value]))
alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas(value))
for key, value in all_agent_data[0].items()
if key not in invalid_tooltips
],
}
has_color = "color" in all_agent_data[0]
if has_color:
encoding_dict["color"] = alt.Color("color", type="nominal")
unique_colors = list({agent["color"] for agent in all_agent_data})
encoding_dict["color"] = alt.Color(
"color:N",
scale=alt.Scale(domain=unique_colors, range=unique_colors),
legend=None,
)
has_size = "size" in all_agent_data[0]
if has_size:
encoding_dict["size"] = alt.Size("size", type="quantitative")
encoding_dict["size"] = alt.Size("size", type="quantitative", legend=None)

chart = (
agent_chart = (
alt.Chart(
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
)
.mark_point(filled=True)
.properties(width=280, height=280)
.properties(width=300, height=300)
# .configure_view(strokeOpacity=0) # hide grid/chart lines
)
# This is the default value for the marker size, which auto-scales
# according to the grid area.
if not has_size:
length = min(space.width, space.height)
chart = chart.mark_point(size=30000 / length**2, filled=True)
chart = agent_chart.mark_point(size=30000 / length**2, filled=True)

if propertylayer_portrayal is not None:
base_width = agent_chart.properties().width
base_height = agent_chart.properties().height
chart = chart_property_layers(
space=space,
propertylayer_portrayal=propertylayer_portrayal,
base_width=base_width,
base_height=base_height,
)

chart = chart + agent_chart
return chart


def chart_property_layers(space, propertylayer_portrayal, base_width, base_height):
"""Creates Property Layers in the Altair Components.
Args:
space: the ContinuousSpace instance
propertylayer_portrayal:Dictionary of PropertyLayer portrayal specifications
base_width: width of the agent chart to maintain consistency with the property charts
base_height: height of the agent chart to maintain consistency with the property charts
Returns:
Altair Chart
"""
try:
# old style spaces
property_layers = space.properties
except AttributeError:
# new style spaces
property_layers = space._mesa_property_layers
base = None
for layer_name, portrayal in propertylayer_portrayal.items():
layer = property_layers.get(layer_name, None)
if not isinstance(
layer,
PropertyLayer | mesa.experimental.cell_space.property_layer.PropertyLayer,
):
continue

data = layer.data.astype(float) if layer.data.dtype == bool else layer.data

if (space.width, space.height) is not data.shape:
warnings.warn(
f"Layer {layer_name} dimensions ({data.shape}) do not match space dimensions ({space.width}, {space.height}).",
UserWarning,
stacklevel=2,
)
alpha = portrayal.get("alpha", 1)
vmin = portrayal.get("vmin", np.min(data))
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Prepare data for Altair (convert 2D array to a long-form DataFrame)
df = pd.DataFrame(
{
"x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
"y": np.tile(np.arange(data.shape[1]), data.shape[0]),
"value": data.flatten(),
}
)

# Add RGBA color if "color" is in portrayal
if "color" in portrayal:
df["color"] = df["value"].apply(
lambda val,
portrayal=portrayal,
alpha=alpha: f"rgba({int(to_rgba(portrayal['color'], alpha=alpha)[0] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[1] * 255)}, {int(to_rgba(portrayal['color'], alpha=alpha)[2] * 255)}, {to_rgba(portrayal['color'], alpha=alpha)[3]:.2f})"
if val > 0
else "rgba(0, 0, 0, 0)"
)
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color("color:N", legend=None),
)
.properties(width=base_width, height=base_height, title=layer_name)
)
base = (base + chart) if base is not None else chart
# Add colormap if "colormap" is in portrayal
elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])

chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("x:O", axis=None),
y=alt.Y("y:O", axis=None),
color=alt.Color(
"value:Q",
scale=cmap_scale,
title=layer_name if colorbar else None,
),
)
.properties(width=base_width, height=base_height, title=layer_name)
)
base = (base + chart) if base is not None else chart

else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)
return chart

0 comments on commit 4492a42

Please sign in to comment.