diff --git a/shapash/explainer/consistency.py b/shapash/explainer/consistency.py
index bd053f11..ac3bfa4a 100644
--- a/shapash/explainer/consistency.py
+++ b/shapash/explainer/consistency.py
@@ -16,8 +16,8 @@
class Consistency:
"""Consistency class"""
- def __init__(self):
- self._palette_name = list(colors_loading().keys())[0]
+ def __init__(self, palette_name="default"):
+ self._palette_name = palette_name
self._style_dict = define_style(select_palette(colors_loading(), self._palette_name))
def tuning_colorscale(self, values):
@@ -454,7 +454,15 @@ def plot_examples(self, method_1, method_2, l2, index, backend_name_1, backend_n
return fig
def pairwise_consistency_plot(
- self, methods, selection=None, max_features=10, max_points=100, file_name=None, auto_open=False
+ self,
+ methods,
+ selection=None,
+ max_features=10,
+ max_points=100,
+ file_name=None,
+ auto_open=False,
+ width=1000,
+ height="auto",
):
"""The Pairwise_Consistency_plot compares the difference of 2 explainability methods across each feature and each data point,
and plots the distribution of those differences.
@@ -480,6 +488,10 @@ def pairwise_consistency_plot(
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot, by default False
+ height : str or int, optional
+ Height of the figure. Default is 'auto'.
+ width : int, optional
+ Width of the figure. Default is 1000.
Returns
@@ -520,11 +532,15 @@ def pairwise_consistency_plot(
mean_contributions = np.mean(np.abs(pd.concat(weights)), axis=0)
top_features = np.flip(mean_contributions.sort_values(ascending=False)[:max_features].keys())
- fig = self.plot_pairwise_consistency(weights, x, top_features, methods, file_name, auto_open)
+ fig = self.plot_pairwise_consistency(
+ weights, x, top_features, methods, file_name, auto_open, width=width, height=height
+ )
return fig
- def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name, auto_open):
+ def plot_pairwise_consistency(
+ self, weights, x, top_features, methods, file_name, auto_open, width=1000, height="auto"
+ ):
"""Plot the main graph displaying distances between methods across each feature and data point
Parameters
@@ -541,6 +557,10 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot
+ height : str or int, optional
+ Height of the figure. Default is 'auto'.
+ width : int, optional
+ Width of the figure. Default is 1000.
Returns
-------
@@ -555,8 +575,7 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
x = encoder.transform(x)
xaxis_title = (
- "Difference of contributions between the 2 methods"
- + f"
{methods[0]} - {methods[1]}"
+ "
Difference of contributions between the 2 methods" + f"
{methods[0]} - {methods[1]}"
)
yaxis_title = (
"Top features
(Ordered by mean of absolute contributions)"
@@ -647,11 +666,15 @@ def plot_pairwise_consistency(self, weights, x, top_features, methods, file_name
yaxis_title=yaxis_title,
file_name=file_name,
auto_open=auto_open,
+ height=height,
+ width=width,
)
return fig
- def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open):
+ def _update_pairwise_consistency_fig(
+ self, fig, top_features, xaxis_title, yaxis_title, file_name, auto_open, height="auto", width=1000
+ ):
"""Function used for the pairwise_consistency_plot to update the layout of the plotly figure.
Parameters
@@ -668,11 +691,19 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
Specify the save path of html files. If it is not provided, no file will be saved.
auto_open: bool
open automatically the plot
+ height : str or int, optional
+ Height of the figure. Default is 'auto'.
+ width : int, optional
+ Width of the figure. Default is 1000.
+
+ Returns
+ -------
+ None
"""
- height = max(500, 40 * len(top_features))
- title = "Pairwise comparison of Consistency:"
- title += "\
-
How are differences in contributions distributed across features?"
+ if height == "auto":
+ height = max(500, 40 * len(top_features) + 300)
+ title = "
Pairwise comparison of Consistency:"
+ title += "
How are differences in contributions distributed across features?"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": xaxis_title}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": yaxis_title}
@@ -681,12 +712,15 @@ def _update_pairwise_consistency_fig(self, fig, top_features, xaxis_title, yaxis
fig.layout.yaxis2.update(showticklabels=False)
fig.update_layout(
template="none",
+ autosize=False,
title=dict_t,
xaxis_title=dict_xaxis,
yaxis_title=dict_yaxis,
yaxis=dict(range=[-0.7, len(top_features) - 0.3]),
yaxis2=dict(range=[-0.7, len(top_features) - 0.3]),
height=height,
+ width=width,
+ margin={"l": 150, "r": 20, "t": 95, "b": 70},
)
fig.update_yaxes(automargin=True, zeroline=False)
diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py
index 5c568b71..f135edb6 100644
--- a/shapash/explainer/smart_plotter.py
+++ b/shapash/explainer/smart_plotter.py
@@ -1428,7 +1428,7 @@ def correlations_plot(
return fig
- def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open=False):
+ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open=False, height="auto", width=900):
"""
The Local_neighbors_plot has the main objective of increasing confidence \
in interpreting the contribution values of a selected instance.
@@ -1450,6 +1450,7 @@ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open
* For classification:
.. math::
distance = |output_{allFeatures} - output_{currentFeatures}|
+
Parameters
----------
index: int
@@ -1460,6 +1461,11 @@ def local_neighbors_plot(self, index, max_features=10, file_name=None, auto_open
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
+ height : str or int, optional
+ Height of the figure. Default is 'auto'.
+ width : int, optional
+ Width of the figure. Default is 900.
+
Returns
-------
fig
@@ -1512,15 +1518,18 @@ def ordinal(n):
]
)
- height = max(500, 11 * g_df.shape[0] * g_df.shape[1])
- title = f"Comparing local explanations in a neighborhood - Id: {index}"
- title += "
How similar are explanations for closeby neighbours?"
+ if height == "auto":
+ height = max(500, 11 * g_df.shape[0] * g_df.shape[1])
+ title = f"
Comparing local explanations in a neighborhood - Id: {index}"
+ title += "
How similar are explanations for closeby neighbours?"
dict_t = self._style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = self._style_dict["dict_xaxis"] | {"text": "Normalized contribution values"}
dict_yaxis = self._style_dict["dict_yaxis"] | {"text": ""}
fig.update_layout(
template="none",
+ autosize=False,
+ width=width,
title=dict_t,
xaxis_title=dict_xaxis,
yaxis_title=dict_yaxis,
@@ -1529,6 +1538,7 @@ def ordinal(n):
height=height,
legend={"traceorder": "reversed"},
xaxis={"side": "bottom"},
+ margin={"l": 150, "r": 20, "t": 95, "b": 70},
)
fig.update_yaxes(automargin=True)
@@ -1548,6 +1558,8 @@ def stability_plot(
distribution="none",
file_name=None,
auto_open=False,
+ height="auto",
+ width=900,
):
"""
The Stability_plot has the main objective of increasing confidence in contribution values, \
@@ -1588,6 +1600,11 @@ def stability_plot(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
+ height: int or 'auto'
+ Plotly figure - layout height
+ width: int
+ Plotly figure - layout width
+
Returns
-------
If single instance:
@@ -1641,6 +1658,8 @@ def stability_plot(
auto_open,
self._style_dict["init_contrib_colorscale"],
self._style_dict,
+ height=height,
+ width=width,
)
# Plot 2 : Show distribution of variability
@@ -1665,12 +1684,23 @@ def stability_plot(
auto_open,
self._style_dict["init_contrib_colorscale"],
self._style_dict,
+ height=height,
+ width=width,
)
return fig
def compacity_plot(
- self, selection=None, max_points=2000, force=False, approx=0.9, nb_features=5, file_name=None, auto_open=False
+ self,
+ selection=None,
+ max_points=2000,
+ force=False,
+ approx=0.9,
+ nb_features=5,
+ file_name=None,
+ auto_open=False,
+ height=600,
+ width=900,
):
"""
The Compacity_plot has the main objective of determining if a small subset of features
@@ -1705,6 +1735,10 @@ def compacity_plot(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
+ height: int, optional
+ height of the plot, by default 600
+ width: int, optional
+ width of the plot, by default 900
"""
# Sampling
if selection is None:
@@ -1735,7 +1769,15 @@ def compacity_plot(
# Plot generation
fig = plot_compacity(
- features_needed, distance_reached, self._style_dict, approx, nb_features, file_name, auto_open
+ features_needed,
+ distance_reached,
+ self._style_dict,
+ approx,
+ nb_features,
+ file_name,
+ auto_open,
+ height,
+ width,
)
return fig
diff --git a/shapash/plots/plot_compacity.py b/shapash/plots/plot_compacity.py
index f5e3a9d2..7ad34f91 100644
--- a/shapash/plots/plot_compacity.py
+++ b/shapash/plots/plot_compacity.py
@@ -6,7 +6,15 @@
def plot_compacity(
- features_needed, distance_reached, style_dict, approx=0.9, nb_features=5, file_name=None, auto_open=False
+ features_needed,
+ distance_reached,
+ style_dict,
+ approx=0.9,
+ nb_features=5,
+ file_name=None,
+ auto_open=False,
+ height=600,
+ width=900,
):
"""
The Compacity_plot has the main objective of determining if a small subset of features \
@@ -39,6 +47,10 @@ def plot_compacity(
Specify the save path of html files. If it is not provided, no file will be saved, by default None
auto_open: bool, optional
open automatically the plot, by default False
+ height: int, optional
+ height of the plot, by default 600
+ width: int, optional
+ width of the plot, by default 900
"""
# Make plots
@@ -102,19 +114,19 @@ def plot_compacity(
title = style_dict["dict_yaxis"] | {"text": "Cumulative distribution over
dataset's instances (%)"}
fig.update_yaxes(title=title, row=1, col=2)
- title = "Compacity of explanations:"
- title += (
- "
How many variables are enough to produce accurate explanations?"
- )
+ title = "
Compacity of explanations:"
+ title += "
How many variables are enough to produce accurate explanations?"
dict_t = style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height()}
fig.update_layout(
template="none",
+ autosize=False,
+ height=height,
+ width=width,
title=dict_t,
- title_y=0.8,
hovermode="closest",
- margin={"t": 150},
showlegend=False,
+ margin={"l": 150, "r": 20, "t": 150, "b": 70},
)
if file_name is not None:
diff --git a/shapash/plots/plot_stability.py b/shapash/plots/plot_stability.py
index 16dac1ff..0d482135 100644
--- a/shapash/plots/plot_stability.py
+++ b/shapash/plots/plot_stability.py
@@ -7,7 +7,17 @@
def plot_stability_distribution(
- variability, plot_type, mean_amplitude, dataset, column_names, file_name, auto_open, init_colorscale, style_dict
+ variability,
+ plot_type,
+ mean_amplitude,
+ dataset,
+ column_names,
+ file_name,
+ auto_open,
+ init_colorscale,
+ style_dict,
+ height="auto",
+ width=900,
):
"""
Generates and displays a stability distribution plot for feature variability using either a boxplot or violin plot.
@@ -43,6 +53,10 @@ def plot_stability_distribution(
style_dict : dict
A dictionary specifying the various style options such as font size, color, and other aesthetic parameters
for the plot.
+ height: int or 'auto'
+ Plotly figure - layout height
+ width: int
+ Plotly figure - layout width
Returns
-------
@@ -70,7 +84,10 @@ def plot_stability_distribution(
color_list = mean_amplitude_normalized.tolist()
color_list.sort()
color_list = [next(pair[1] for pair in col_scale if x <= pair[0]) for x in color_list]
- height_value = max(500, 40 * dataset.shape[1] if dataset.shape[1] < 100 else 13 * dataset.shape[1])
+ if height == "auto":
+ height_value = max(500, 40 * dataset.shape[1] if dataset.shape[1] < 100 else 13 * dataset.shape[1])
+ else:
+ height_value = height
xaxis_title = "Normalized local contribution value variability"
yaxis_title = ""
@@ -134,12 +151,15 @@ def plot_stability_distribution(
file_name=file_name,
auto_open=auto_open,
height=height_value,
+ width=width,
)
return fig
-def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_title, file_name, auto_open, height=500):
+def _update_stability_fig(
+ fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_title, file_name, auto_open, height=500, width=900
+):
"""
Function used for the `plot_stability_distribution` and `plot_amplitude_vs_stability`
to update the layout of the plotly figure.
@@ -169,8 +189,8 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
-------
go.Figure
"""
- title = "Importance & Local Stability of explanations:"
- title += "
How similar are explanations for closeby neighbours?"
+ title = "
Importance & Local Stability of explanations:"
+ title += "
How similar are explanations for closeby neighbours?"
dict_t = style_dict["dict_title_stability"] | {"text": title, "y": adjust_title_height(height)}
dict_xaxis = style_dict["dict_xaxis"] | {"text": xaxis_title}
@@ -200,12 +220,15 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
fig.update_layout(
template="none",
+ autosize=False,
title=dict_t,
xaxis_title=dict_xaxis,
yaxis_title=dict_yaxis,
coloraxis_showscale=False,
hovermode="closest",
height=height,
+ width=width,
+ margin={"l": 150, "r": 20, "t": 95, "b": 70},
)
fig.update_yaxes(automargin=True)
@@ -216,7 +239,15 @@ def _update_stability_fig(fig, x_barlen, y_bar, style_dict, xaxis_title, yaxis_t
def plot_amplitude_vs_stability(
- mean_variability, mean_amplitude, column_names, file_name, auto_open, col_scale, style_dict
+ mean_variability,
+ mean_amplitude,
+ column_names,
+ file_name,
+ auto_open,
+ col_scale,
+ style_dict,
+ height="auto",
+ width=900,
):
"""
Generates and displays a scatter plot showing the relationship between feature variability and importance.
@@ -247,6 +278,10 @@ def plot_amplitude_vs_stability(
style_dict : dict
A dictionary specifying various style options such as font size, axis formatting, and other aesthetic
properties for the plot.
+ height: int or 'auto'
+ Plotly figure - layout height ('auto' resolves to 500)
+ width: int
+ Plotly figure - layout width
Returns
-------
@@ -262,6 +297,8 @@ def plot_amplitude_vs_stability(
- The function can optionally save the plot as an HTML file for further exploration.
"""
+ if height == "auto":
+ height = 500
xaxis_title = (
"Variability of the Normalized Local Contribution Values"
+ "
(standard deviation / mean)"
@@ -301,5 +338,7 @@ def plot_amplitude_vs_stability(
yaxis_title=yaxis_title,
file_name=file_name,
auto_open=auto_open,
+ height=height,
+ width=width,
)
return fig
diff --git a/tutorial/explainability_quality/tuto-quality01-Builing-confidence-explainability.ipynb b/tutorial/explainability_quality/tuto-quality01-Builing-confidence-explainability.ipynb
index 95cda81b..1dad7fae 100644
--- a/tutorial/explainability_quality/tuto-quality01-Builing-confidence-explainability.ipynb
+++ b/tutorial/explainability_quality/tuto-quality01-Builing-confidence-explainability.ipynb
@@ -28,7 +28,6 @@
"outputs": [],
"source": [
"import pandas as pd\n",
- "from category_encoders import OrdinalEncoder\n",
"from sklearn.ensemble import ExtraTreesClassifier\n",
"from sklearn.model_selection import train_test_split"
]
@@ -207,7 +206,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Load Titanic data"
+ "Encode Titanic data"
]
},
{
@@ -241,11 +240,13 @@
"metadata": {},
"outputs": [],
"source": [
+ "n=200\n",
+ "\n",
"Xtrain, Xtest, ytrain, ytest = train_test_split(X_df, y_df, train_size=0.75, random_state=7)\n",
"\n",
"# Subsample\n",
- "Xtrain = Xtrain[:50].reset_index(drop=True)\n",
- "ytrain = ytrain[:50].reset_index(drop=True)"
+ "Xtrain = Xtrain[:n].reset_index(drop=True)\n",
+ "ytrain = ytrain[:n].reset_index(drop=True)"
]
},
{
@@ -254,7 +255,7 @@
"metadata": {},
"outputs": [],
"source": [
- "clf = ExtraTreesClassifier(n_estimators=200).fit(Xtrain, ytrain)"
+ "clf = ExtraTreesClassifier(n_estimators=200).fit(Xtrain, ytrain.iloc[:,0])"
]
},
{
@@ -272,248 +273,153 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "#### First, we need to instantiate and compile the Consistency object"
+ "#### Use pre-computed contributions\n",
+ "\n",
+ "First, compute the contributions beforehand, then pass them to the metric through the _contributions_ argument.\n",
+ "\n",
+ "The contributions must be provided as a dictionary whose keys are the method names and whose values are pandas DataFrames holding the contributions.\n",
+ "\n",
+ "For example, let's compute the contributions separately:"
]
},
{
- "cell_type": "code",
- "execution_count": 7,
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "from shapash.explainer.consistency import Consistency"
+ "#### Then, we need to instantiate and compile the Consistency object"
]
},
{
"cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING:root:No train set passed. We recommend to pass the x_train parameter in order to avoid errors.\n",
- " 27%|██▋ | 54/200 [00:00<00:00, 266.79it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Backend: Shap TreeExplainer\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 200/200 [00:00<00:00, 277.31it/s]\n",
- " 0%| | 0/8 [00:00, ?it/s]"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Backend: ACV\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 38%|███▊ | 3/8 [00:05<00:09, 1.95s/it]\n",
- "WARNING:root:No train set passed. We recommend to pass the x_train parameter in order to avoid errors.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Backend: LIME\n"
- ]
- }
- ],
- "source": [
- "cns = Consistency()\n",
- "cns.compile(x=Xtrain, # Dataset for which we need explanations\n",
- " model=clf, # Model to explain\n",
- " preprocessing=encoder, # Optional\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
+ "execution_count": 7,
"metadata": {},
+ "outputs": [],
"source": [
- "#### We can now display the consistency plot:"
+ "from shapash.explainer.consistency import Consistency\n",
+ "import shap"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
- "image/png": "",
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c5bf988ead634b97ba9f4fa597832722",
+ "version_major": 2,
+ "version_minor": 0
+ },
"text/plain": [
- "