Skip to content

Commit

Permalink
Merge pull request #609 from guillaume-vignal/feature/fix_title_height
Browse files Browse the repository at this point in the history
Add Customizable Graph Size, Improve Title Alignment, and Fix Color Palette Selection in Shapash Explainability Quality Graphs
  • Loading branch information
guillaume-vignal authored Oct 28, 2024
2 parents 8ebdb85 + a15bd74 commit 455fdc3
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 225 deletions.
58 changes: 46 additions & 12 deletions shapash/explainer/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
class Consistency:
"""Consistency class"""

def __init__(self):
self._palette_name = list(colors_loading().keys())[0]
def __init__(self, palette_name="default"):
self._palette_name = palette_name
self._style_dict = define_style(select_palette(colors_loading(), self._palette_name))

def tuning_colorscale(self, values):
Expand Down Expand Up @@ -454,7 +454,15 @@ def plot_examples(self, method_1, method_2, l2, index, backend_name_1, backend_n
return fig

def pairwise_consistency_plot(
self, methods, selection=None, max_features=10, max_points=100, file_name=None, auto_open=False
self,
methods,
selection=None,
max_features=10,
max_points=100,
file_name=None,
auto_open=False,
width=1000,
height="auto",
):
"""The Pairwise_Consistency_plot compares the difference of 2 explainability methods across each feature and each data point,
and plots the distribution of those differences.
Expand All @@ -480,6 +488,10 @@ def pairwise_consistency_plot(
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot, by default False
height : str or int, optional
Height of the figure. Default is 'auto'.
width : int, optional
Width of the figure. Default is 1000.
Returns
Expand Down Expand Up @@ -520,11 +532,15 @@ def pairwise_consistency_plot(
mean_contributions = np.mean(np.abs(pd.concat(weights)), axis=0)
top_features = np.flip(mean_contributions.sort_values(ascending=False)[:max_features].keys())

fig = self.plot_pairwise_consistency(weights, x, top_features, methods, file_name, auto_open)
fig = self.plot_pairwise_consistency(
weights, x, top_features, methods, file_name, auto_open, width=width, height=height
)

return fig

def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name, auto_open):
def plot_pairwise_consistency(
self, weights, x, top_features, methods, file_name, auto_open, width=1000, height="auto"
):
"""Plot the main graph displaying distances between methods across each feature and data point
Parameters
Expand All @@ -541,6 +557,10 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot
height : str or int, optional
Height of the figure. Default is 'auto'.
width : int, optional
Width of the figure. Default is 1000.
Returns
-------
Expand All @@ -555,8 +575,7 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
x = encoder.transform(x)

xaxis_title = (
"Difference of contributions between the 2 methods"
+ f"<span style='font-size: 12px;'><br />{methods[0]} - {methods[1]}</span>"
"<br>Difference of contributions between the 2 methods" + f"<br><sup>{methods[0]} - {methods[1]}</sup>"
)
yaxis_title = (
"Top features<span style='font-size: 12px;'><br />(Ordered by mean of absolute contributions)</span>"
Expand Down Expand Up @@ -647,11 +666,15 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
yaxis_title=yaxis_title,
file_name=file_name,
auto_open=auto_open,
height=height,
width=width,
)

return fig

def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open):
def _update_pairwise_consistency_fig(
self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open, height="auto", width=1000
):
"""Function used for the pairwise_consistency_plot to update the layout of the plotly figure.
Parameters
Expand All @@ -668,11 +691,19 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot
height : str or int, optional
Height of the figure. Default is 'auto'.
width : int, optional
Width of the figure. Default is 1000.
Returns
-------
None
"""
height = max(500, 40 * len(top_features))
title = "Pairwise comparison of Consistency:"
title += "<span style='font-size: 16px;'>\
<br />How are differences in contributions distributed across features?</span>"
if height == "auto":
height = max(500, 40 * len(top_features) + 300)
title = "<br>Pairwise comparison of Consistency:"
title += "<br><sup>How are differences in contributions distributed across features?</sup>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": yaxis_title}
Expand All @@ -681,12 +712,15 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
fig.layout.yaxis2.update(showticklabels=False)
fig.update_layout(
template="none",
autosize=False,
title=dict_t,
xaxis_title=dict_xaxis,
yaxis_title=dict_yaxis,
yaxis=dict(range=[-0.7, len(top_features) - 0.3]),
yaxis2=dict(range=[-0.7, len(top_features) - 0.3]),
height=height,
width=width,
margin={"l": 150, "r": 20, "t": 95, "b": 70},
)

fig.update_yaxes(automargin=True, zeroline=False)
Expand Down
54 changes: 48 additions & 6 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,7 +1428,7 @@ def correlations_plot(

return fig

def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open=False):
def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open=False, height="auto", width=900):
"""
The Local_neighbors_plot has the main objective of increasing confidence \
in interpreting the contribution values of a selected instance.
Expand All @@ -1450,6 +1450,7 @@ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open
* For classification:
.. math::
distance = |output_{allFeatures} - output_{currentFeatures}|
Parameters
----------
index: int
Expand All @@ -1460,6 +1461,11 @@ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
height : str or int, optional
Height of the figure. Default is 'auto'.
width : int, optional
Width of the figure. Default is 900.
Returns
-------
fig
Expand Down Expand Up @@ -1512,15 +1518,18 @@ def ordinal(n):
]
)

height = max(500, 11 * g_df.shape[0] * g_df.shape[1])
title = f"Comparing local explanations in a neighborhood - Id: <b>{index}</b>"
title += "<span style='font-size: 16px;'><br />How similar are explanations for closeby neighbours?</span>"
if height == "auto":
height = max(500, 11 * g_df.shape[0] * g_df.shape[1])
title = f"<br>Comparing local explanations in a neighborhood - Id: <b>{index}</b>"
title += "<br><sup>How similar are explanations for closeby neighbours?</sup>"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": "Normalized contribution values"}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": ""}

fig.update_layout(
template="none",
autosize=False,
width=width,
title=dict_t,
xaxis_title=dict_xaxis,
yaxis_title=dict_yaxis,
Expand All @@ -1529,6 +1538,7 @@ def ordinal(n):
height=height,
legend={"traceorder": "reversed"},
xaxis={"side": "bottom"},
margin={"l": 150, "r": 20, "t": 95, "b": 70},
)

fig.update_yaxes(automargin=True)
Expand All @@ -1548,6 +1558,8 @@ def stability_plot(
distribution="none",
file_name=None,
auto_open=False,
height="auto",
width=900,
):
"""
The Stability_plot has the main objective of increasing confidence in contribution values, \
Expand Down Expand Up @@ -1588,6 +1600,11 @@ def stability_plot(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
height: int or 'auto'
Plotly figure - layout height
width: int
Plotly figure - layout width
Returns
-------
If single instance:
Expand Down Expand Up @@ -1641,6 +1658,8 @@ def stability_plot(
auto_open,
self._style_dict["init_contrib_colorscale"],
self._style_dict,
height=height,
width=width,
)

# Plot 2 : Show distribution of variability
Expand All @@ -1665,12 +1684,23 @@ def stability_plot(
auto_open,
self._style_dict["init_contrib_colorscale"],
self._style_dict,
height=height,
width=width,
)

return fig

def compacity_plot(
self, selection=None, max_points=2000, force=False, approx=0.9, nb_features=5, file_name=None, auto_open=False
self,
selection=None,
max_points=2000,
force=False,
approx=0.9,
nb_features=5,
file_name=None,
auto_open=False,
height=600,
width=900,
):
"""
The Compacity_plot has the main objective of determining if a small subset of features
Expand Down Expand Up @@ -1705,6 +1735,10 @@ def compacity_plot(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
height: int, optional
height of the plot, by default 600
width: int, optional
width of the plot, by default 900
"""
# Sampling
if selection is None:
Expand Down Expand Up @@ -1735,7 +1769,15 @@ def compacity_plot(

# Plot generation
fig = plot_compacity(
features_needed, distance_reached, self._style_dict, approx, nb_features, file_name, auto_open
features_needed,
distance_reached,
self._style_dict,
approx,
nb_features,
file_name,
auto_open,
height,
width,
)

return fig
Expand Down
26 changes: 19 additions & 7 deletions shapash/plots/plot_compacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@


def plot_compacity(
features_needed, distance_reached, style_dict, approx=0.9, nb_features=5, file_name=None, auto_open=False
features_needed,
distance_reached,
style_dict,
approx=0.9,
nb_features=5,
file_name=None,
auto_open=False,
height=600,
width=900,
):
"""
The Compacity_plot has the main objective of determining if a small subset of features \
Expand Down Expand Up @@ -39,6 +47,10 @@ def plot_compacity(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
height: int, optional
height of the plot, by default 600
width: int, optional
width of the plot, by default 900
"""

# Make plots
Expand Down Expand Up @@ -102,19 +114,19 @@ def plot_compacity(
title = style_dict["dict_yaxis"] | {"text": "Cumulative distribution over<br>dataset's instances (%)"}
fig.update_yaxes(title=title, row=1, col=2)

title = "Compacity of explanations:"
title += (
"<span style='font-size: 16px;'><br />How many variables are enough to produce accurate explanations?</span>"
)
title = "<br>Compacity of explanations:"
title += "<br><sup>How many variables are enough to produce accurate explanations?</sup>"
dict_t = style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height()}

fig.update_layout(
template="none",
autosize=False,
height=height,
width=width,
title=dict_t,
title_y=0.8,
hovermode="closest",
margin={"t": 150},
showlegend=False,
margin={"l": 150, "r": 20, "t": 150, "b": 70},
)

if file_name is not None:
Expand Down
Loading

0 comments on commit 455fdc3

Please sign in to comment.