Skip to content

Commit

Permalink
ENH: add Moran scatterplot ported from splot (#356)
Browse files Browse the repository at this point in the history
* implementation

* docstring

* tests

* expose on Moran as well

* Apply suggestions from code review

Co-authored-by: James Gaboardi <[email protected]>

* cleanup

* use plot_scatter

---------

Co-authored-by: James Gaboardi <[email protected]>
  • Loading branch information
martinfleis and jGaboardi authored Jan 5, 2025
1 parent e095b21 commit e38e151
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 10 deletions.
163 changes: 162 additions & 1 deletion esda/moran.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"Levi John Wolf <[email protected]>"
)

from warnings import simplefilter
from warnings import simplefilter, warn

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -321,6 +321,37 @@ def by_col(
**stat_kws,
)

def plot_scatter(
    self,
    ax=None,
    scatter_kwds=None,
    fitline_kwds=None,
):
    """
    Draw the Moran scatterplot for this statistic.

    The call is forwarded to ``_scatterplot`` with ``crit_value=None``,
    i.e. no significance-based coloring is requested.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes for the plot, by default None.
    scatter_kwds : dict, optional
        Additional keyword arguments for scatter plot, by default None.
    fitline_kwds : dict, optional
        Additional keyword arguments for fit line, by default None.

    Returns
    -------
    matplotlib.axes.Axes
        Axes object with the Moran scatterplot.
    """
    plot_options = {
        "crit_value": None,
        "ax": ax,
        "scatter_kwds": scatter_kwds,
        "fitline_kwds": fitline_kwds,
    }
    return _scatterplot(self, **plot_options)


class Moran_BV: # noqa: N801
"""
Expand Down Expand Up @@ -1272,6 +1303,40 @@ def plot(self, gdf, crit_value=0.05, **kwargs):
gdf["Moran Cluster"] = self.get_cluster_labels(crit_value)
return _viz_local_moran(self, gdf, crit_value, "plot", **kwargs)

def plot_scatter(
    self,
    crit_value=0.05,
    ax=None,
    scatter_kwds=None,
    fitline_kwds=None,
):
    """
    Draw the Moran scatterplot for this local statistic.

    All arguments are forwarded unchanged to ``_scatterplot`` together with
    this object.

    Parameters
    ----------
    crit_value : float, optional
        Critical value to determine statistical significance, by default 0.05.
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes for the plot, by default None.
    scatter_kwds : dict, optional
        Additional keyword arguments for scatter plot, by default None.
    fitline_kwds : dict, optional
        Additional keyword arguments for fit line, by default None.

    Returns
    -------
    matplotlib.axes.Axes
        Axes object with the Moran scatterplot.
    """
    axes = _scatterplot(
        self,
        crit_value=crit_value,
        ax=ax,
        scatter_kwds=scatter_kwds,
        fitline_kwds=fitline_kwds,
    )
    return axes


class Moran_Local_BV: # noqa: N801
"""Bivariate Local Moran Statistics.
Expand Down Expand Up @@ -1863,6 +1928,102 @@ def _get_cluster_labels(moran_local, crit_value):
return gdf["Moran Cluster"].values


def _scatterplot(
    moran,
    crit_value=0.05,
    ax=None,
    scatter_kwds=None,
    fitline_kwds=None,
):
    """Generate a Moran Local or Global scatterplot.

    Plots ``moran.z`` against its spatial lag, together with dashed zero
    reference lines and the least-squares fit line.

    Parameters
    ----------
    moran : Moran object
        An instance of a Moran or Moran_Local object. Must expose ``w`` and
        ``z``; ``get_cluster_labels`` is additionally required when points
        are colored by cluster (see ``crit_value``).
    crit_value : float, optional
        The critical value for significance. Default is 0.05. When not None,
        points are colored by the labels returned from
        ``moran.get_cluster_labels(crit_value)`` unless the caller supplied
        ``c`` or ``color`` in ``scatter_kwds``. When None, all points share
        one default color.
    ax : matplotlib.axes.Axes, optional
        The axes on which to draw the plot. If None, a new figure and axes
        are created.
    scatter_kwds : dict, optional
        Additional keyword arguments to pass to the scatter plot. The
        dictionary is copied and never modified.
    fitline_kwds : dict, optional
        Additional keyword arguments to pass to the fit line plot. The
        dictionary is copied and never modified.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes with the Moran Scatterplot.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    try:
        from matplotlib import pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib library must be installed to use the scatterplot feature"
        ) from err

    # Work on shallow copies: the defaults filled in below must not leak back
    # into dictionaries owned (and possibly reused) by the caller.
    scatter_kwds = {} if scatter_kwds is None else dict(scatter_kwds)
    fitline_kwds = {} if fitline_kwds is None else dict(fitline_kwds)

    scatter_kwds.setdefault("alpha", 0.6)
    fitline_kwds.setdefault("alpha", 0.9)

    # ``Axes.scatter`` rejects ``c`` and ``color`` given together, so a
    # default for one is only injected when the caller supplied neither.
    user_sets_point_color = "c" in scatter_kwds or "color" in scatter_kwds

    if crit_value is None:
        fitline_kwds.setdefault("color", "#d6604d")
        if not user_sets_point_color:
            scatter_kwds["color"] = "#bababa"
    else:
        fitline_kwds.setdefault("color", "k")
        if not user_sets_point_color:
            labels = moran.get_cluster_labels(crit_value)
            # TODO: allow customization of colors in here and in plot and explore
            # TODO: in a way to keep them easily synced
            colors5_mpl = {
                "High-High": "#d7191c",
                "Low-High": "#89cff0",
                "Low-Low": "#2c7bb6",
                "High-Low": "#fdae61",
                "Insignificant": "lightgrey",
            }
            scatter_kwds["c"] = [colors5_mpl[i] for i in labels]  # for mpl

    if ax is None:
        _, ax = plt.subplots()

    # set labels
    ax.set_xlabel("Attribute")
    ax.set_ylabel("Spatial Lag")
    # NOTE(review): the title says "Local" even when a global Moran object is
    # plotted -- confirm whether that is intended.
    ax.set_title("Moran Local Scatterplot")

    # regress the spatial lag on the standardized attribute
    lag = lag_spatial(moran.w, moran.z)
    fit = stats.linregress(moran.z, lag)

    # v- and hlines first, so the fit line ends up as the third line on the axes
    ax.axvline(0, alpha=0.5, color="k", linestyle="--")
    ax.axhline(0, alpha=0.5, color="k", linestyle="--")
    ax.plot(moran.z, fit.intercept + fit.slope * moran.z, **fitline_kwds)
    ax.scatter(moran.z, lag, **scatter_kwds)

    ax.set_aspect("equal")

    return ax


# --------------------------------------------------------------
# Conditional Randomization Moment Estimators
# --------------------------------------------------------------
Expand Down
153 changes: 144 additions & 9 deletions esda/tests/test_moran.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,58 @@ def test_by_col(self):
np.testing.assert_allclose(sidr, 0.24772519320480135, atol=ATOL, rtol=RTOL)
np.testing.assert_allclose(pval, 0.001)

@parametrize_sac
def test_Moran_plot_scatter(self, w):
    """Default global scatterplot: one point color and the default fit line."""
    import matplotlib

    # non-interactive backend, no display needed
    matplotlib.use("Agg")

    ax = moran.Moran(sac1.WHITE, w).plot_scatter()

    # scatter: a single shared face color
    expected_rgba = np.array([[0.729412, 0.729412, 0.729412, 0.6]])
    np.testing.assert_array_almost_equal(
        ax.collections[0].get_facecolors(),
        expected_rgba,
    )

    # fit line: extent of the plotted data and its color
    fitline = ax.lines[2]
    xs, ys = fitline.get_data()
    extent_checks = (
        (xs.min(), -1.8236414387225368),
        (xs.max(), 3.893056527659032),
        (ys.min(), -0.7371749399524187),
        (ys.max(), 1.634939204358587),
    )
    for observed, expected in extent_checks:
        np.testing.assert_almost_equal(observed, expected)
    assert fitline.get_color() == "#d6604d"

@parametrize_sac
def test_Moran_plot_scatter_args(self, w):
    """Colors passed through the keyword dictionaries reach the artists."""
    import matplotlib

    # non-interactive backend, no display needed
    matplotlib.use("Agg")

    global_moran = moran.Moran(sac1.WHITE, w)
    ax = global_moran.plot_scatter(
        scatter_kwds={"color": "blue"},
        fitline_kwds={"color": "pink"},
    )

    # scatter: requested blue with alpha 0.6
    blue_rgba = np.array([[0, 0, 1, 0.6]])
    np.testing.assert_array_almost_equal(
        ax.collections[0].get_facecolors(),
        blue_rgba,
    )

    # fit line: requested color
    fitline = ax.lines[2]
    assert fitline.get_color() == "pink"



class TestMoranRate:
def setup_method(self):
Expand Down Expand Up @@ -251,15 +303,98 @@ def test_Moran_Local_plot(self, w):
seed=SEED,
)
ax = lm.plot(sac1)
unique, counts = np.unique(ax.collections[0].get_facecolors(), axis=0, return_counts=True)
np.testing.assert_array_almost_equal(unique, np.array([
[0.17254902, 0.48235294, 0.71372549, 1.],
[0.5372549 , 0.81176471, 0.94117647, 1.],
[0.82745098, 0.82745098, 0.82745098, 1.],
[0.84313725, 0.09803922, 0.10980392, 1.],
[0.99215686, 0.68235294, 0.38039216, 1.]]
))
np.testing.assert_array_equal(counts, np.array([86,3, 298,38, 3]))
unique, counts = np.unique(
ax.collections[0].get_facecolors(), axis=0, return_counts=True
)
np.testing.assert_array_almost_equal(
unique,
np.array(
[
[0.17254902, 0.48235294, 0.71372549, 1.0],
[0.5372549, 0.81176471, 0.94117647, 1.0],
[0.82745098, 0.82745098, 0.82745098, 1.0],
[0.84313725, 0.09803922, 0.10980392, 1.0],
[0.99215686, 0.68235294, 0.38039216, 1.0],
]
),
)
np.testing.assert_array_equal(counts, np.array([86, 3, 298, 38, 3]))

@parametrize_sac
def test_Moran_Local_plot_scatter(self, w):
    """Default local scatterplot: five cluster colors and a black fit line."""
    import matplotlib

    # non-interactive backend, no display needed
    matplotlib.use("Agg")

    local_moran = moran.Moran_Local(
        sac1.WHITE,
        w,
        transformation="r",
        permutations=99,
        keep_simulations=True,
        seed=SEED,
    )
    ax = local_moran.plot_scatter()

    # scatter: distinct face colors and how many points carry each
    facecolors = ax.collections[0].get_facecolors()
    unique, counts = np.unique(facecolors, axis=0, return_counts=True)
    expected_colors = np.array(
        [
            [0.17254902, 0.48235294, 0.71372549, 0.6],
            [0.5372549, 0.81176471, 0.94117647, 0.6],
            [0.82745098, 0.82745098, 0.82745098, 0.6],
            [0.84313725, 0.09803922, 0.10980392, 0.6],
            [0.99215686, 0.68235294, 0.38039216, 0.6],
        ]
    )
    np.testing.assert_array_almost_equal(unique, expected_colors)
    np.testing.assert_array_equal(counts, np.array([73, 12, 261, 52, 5]))

    # fit line: extent of the plotted data and its color
    fitline = ax.lines[2]
    xs, ys = fitline.get_data()
    extent_checks = (
        (xs.min(), -1.8236414387225368),
        (xs.max(), 3.893056527659032),
        (ys.min(), -0.7371749399524187),
        (ys.max(), 1.634939204358587),
    )
    for observed, expected in extent_checks:
        np.testing.assert_almost_equal(observed, expected)
    assert fitline.get_color() == "k"

@parametrize_sac
def test_Moran_Local_plot_scatter_args(self, w):
    """With ``crit_value=None`` the plot uses one color; extra kwargs apply."""
    import matplotlib

    # non-interactive backend, no display needed
    matplotlib.use("Agg")

    local_moran = moran.Moran_Local(
        sac1.WHITE,
        w,
        transformation="r",
        permutations=99,
        keep_simulations=True,
        seed=SEED,
    )
    ax = local_moran.plot_scatter(
        crit_value=None,
        scatter_kwds={"s": 10},
        fitline_kwds={"linewidth": 4},
    )

    # scatter: single shared face color and the requested marker size
    points = ax.collections[0]
    np.testing.assert_array_almost_equal(
        points.get_facecolors(),
        np.array([[0.729412, 0.729412, 0.729412, 0.6]]),
    )
    assert points.get_sizes()[0] == 10

    # fit line: default color and the requested width
    fitline = ax.lines[2]
    assert fitline.get_color() == "#d6604d"
    assert fitline.get_linewidth() == 4.0

@parametrize_desmith
def test_Moran_Local_parallel(self, w):
Expand Down

0 comments on commit e38e151

Please sign in to comment.