Skip to content

Commit

Permalink
Merge pull request #564 from guillaume-vignal/feature/shapash_report_…
Browse files Browse the repository at this point in the history
…improvment

Feature/shapash report improvment
  • Loading branch information
guillaume-vignal authored Jul 4, 2024
2 parents 12c4f3e + 5d3edbd commit 62703a2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 10 deletions.
8 changes: 8 additions & 0 deletions shapash/explainer/smart_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,8 @@ def generate_report(
working_dir=None,
notebook_path=None,
kernel_name=None,
max_points=200,
nb_top_interactions=5,
):
"""
This method will generate an HTML report containing different information about the project.
Expand Down Expand Up @@ -1233,6 +1235,10 @@ def generate_report(
Name of the kernel used to generate the report. This parameter can be useful if
you have multiple jupyter kernels and that the method does not use the right kernel
by default.
max_points : int, optional
Maximum number of points displayed in the contribution and interaction plots (default: 200).
nb_top_interactions : int, optional
Number of top interactions to display in the report (default: 5).
Examples
--------
>>> xpl.generate_report(
Expand Down Expand Up @@ -1284,6 +1290,8 @@ def generate_report(
title_story=title_story,
title_description=title_description,
metrics=metrics,
max_points=max_points,
nb_top_interactions=nb_top_interactions,
),
notebook_path=notebook_path,
kernel_name=kernel_name,
Expand Down
32 changes: 27 additions & 5 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2639,6 +2639,7 @@ def generate_title_dict(col_name1, col_name2, addnote):
def correlations(
self,
df=None,
optimized=False,
max_features=20,
features_to_hide=None,
facet_col=None,
Expand All @@ -2658,6 +2659,9 @@ def correlations(
----------
df : pd.DataFrame, optional
DataFrame for which we want to compute correlations. Will use x_init by default.
optimized : boolean, optional
True if we want to potentially accelerate the computation of the correlation matrix by reducing the
length of the data (sampled down to 10000 rows) and the number of modalities per categorical column (top 200 kept).
max_features : int (default: 20)
Max number of features to show on the matrix.
features_to_hide : list (optional)
Expand Down Expand Up @@ -2731,7 +2735,17 @@ def cluster_corr(corr, degree, inplace=False):

if df is None:
# Use x_init by default
df = self.explainer.x_init
df = self.explainer.x_init.copy()

if optimized:
categorical_columns = df.select_dtypes(include=["object", "category"]).columns

for col in categorical_columns:
top_categories = df[col].value_counts().nlargest(200).index
df[col] = df[col].where(df[col].isin(top_categories), other="Other")

if len(df) > 10000:
df = df.sample(n=10000, random_state=1)

if facet_col:
features_to_hide += [facet_col]
Expand All @@ -2758,12 +2772,16 @@ def cluster_corr(corr, degree, inplace=False):
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = list(corr.columns)
k = 6
list_features_shorten = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
]

fig.add_trace(
go.Heatmap(
z=corr.loc[list_features, list_features].round(decimals).values,
x=list_features,
y=list_features,
x=list_features_shorten,
y=list_features_shorten,
coloraxis="coloraxis",
text=[
[
Expand All @@ -2784,12 +2802,16 @@ def cluster_corr(corr, degree, inplace=False):
top_features = compute_top_correlations_features(corr=corr, max_features=max_features)
corr = cluster_corr(corr.loc[top_features, top_features], degree=degree)
list_features = [col for col in corr.columns if col in top_features]
k = 6
list_features_shorten = [
x.replace(x[k + k // 2 : -k + k // 2], "...") if len(x) > 2 * k else x for x in list_features
]

fig = go.Figure(
go.Heatmap(
z=corr.loc[list_features, list_features].round(decimals).values,
x=list_features,
y=list_features,
x=list_features_shorten,
y=list_features_shorten,
coloraxis="coloraxis",
text=[
[
Expand Down
24 changes: 21 additions & 3 deletions shapash/report/html/explainability.html
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ <h3 data-toc-skip>Global feature importance plot</h3>
{% for label in labels %}
<div class="row" id="explain-all-div-{{ label['name'] }}" style="{% if label['index'] != 0 %}display:none;{% endif %} margin-right:5px; margin-left:5px; margin-top:5px;">
{{ label['feature_importance_plot'] }}
{% with menuId='dropdownMenu2', menuText='Feature', values=label['features'], menuDivVisible='explain-contrib-'~label['index'] %}
{% include "dropdown.html" %}
{% endwith %}
</div>
{% endfor %}
<h3>Features contribution plots</h3>
{% for label in labels %}
<div class="row" id="explain-all-div-{{ label['name'] }}-2" style="{% if label['index'] != 0 %}display:none;{% endif %} margin-right:5px; margin-left:5px; margin-top:5px;">
{% with menuId='dropdownMenu2', menuText='Feature', values=label['features'], menuDivVisible='explain-contrib-'~label['index'] %}
{% include "dropdown.html" %}
{% endwith %}
{% for col in label['features'] %}
<div class="row" id="explain-contrib-{{ label['index'] }}-div-{{ col['name'] }}" style="{% if col['feature_index'] != 0 %}display:none;{% endif %} margin-right:5px; margin-left:5px;">
<h4>{{ col['name'] }} - {{ col['type'] }}</h4>
Expand All @@ -28,3 +28,21 @@ <h4>{{ col['name'] }} - {{ col['type'] }}</h4>
{% endfor %}
</div>
{% endfor %}
<h3>Features Top Interaction plots</h3>
{% for label in labels %}
<div class="row" id="explain-all-div-interaction-{{ label['name'] }}-2" style="{% if label['index'] != 0 %}display:none;{% endif %} margin-right:5px; margin-left:5px; margin-top:5px;">
{% with menuId='dropdownMenu3', menuText='Interactions', values=label['features_interaction'], menuDivVisible='explain-contrib-interaction-'~label['index'] %}
{% include "dropdown.html" %}
{% endwith %}
{% for col in label['features_interaction'] %}
<div class="row" id="explain-contrib-interaction-{{ label['index'] }}-div-{{ col['name'] }}" style="{% if col['feature_index'] != 0 %}display:none;{% endif %} margin-right:5px; margin-left:5px;">
<h4>{{ col['name'] }} - {{ col['type'] }}</h4>
{% if col['name'] != col['description'] %}
<blockquote class="panel-content">{{ col['description'] }}</blockquote>
{% else %}
{% endif %}
{{ col['plot'] }}
</div>
{% endfor %}
</div>
{% endfor %}
50 changes: 48 additions & 2 deletions shapash/report/project_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from shapash.utils.io import load_yml
from shapash.utils.transform import apply_postprocessing, handle_categorical_missing, inverse_transform
from shapash.utils.utils import get_project_root, truncate_str
from shapash.utils.utils import compute_sorted_variables_interactions_list_indices, get_project_root, truncate_str
from shapash.webapp.utils.utils import round_to_k

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -98,6 +98,16 @@ def __init__(
self.y_train, target_name_train = self._get_values_and_name(y_train, "target")
self.target_name = target_name_train or target_name_test

if "max_points" in self.config.keys():
self.max_points = config["max_points"]
else:
self.max_points = 200

if "nb_top_interactions" in self.config.keys():
self.nb_top_interactions = config["nb_top_interactions"]
else:
self.nb_top_interactions = 5

if "title_story" in self.config.keys():
self.title_story = config["title_story"]
elif self.explainer.title_story != "":
Expand Down Expand Up @@ -308,6 +318,7 @@ def display_dataset_analysis(
print_md("### Multivariate analysis")
fig_corr = self.explainer.plot.correlations(
self.df_train_test,
optimized=True,
facet_col="data_train_test",
max_features=20,
width=900 if len(self.df_train_test["data_train_test"].unique()) > 1 else 500,
Expand Down Expand Up @@ -389,13 +400,16 @@ def display_model_explainability(self):
c_list = self.explainer._classes if multiclass else [1] # list just used for multiclass
for index_label, label in enumerate(c_list): # Iterating over all labels in multiclass case
label_value = self.explainer.check_label_name(label)[2] if multiclass else ""

# Feature Importance
fig_features_importance = self.explainer.plot.features_importance(label=label)

# Contribution Plot
explain_contrib_data = list()
list_cols_labels = [self.explainer.features_dict.get(col, col) for col in self.col_names]
for feature_label in sorted(list_cols_labels):
feature = self.explainer.inv_features_dict.get(feature_label, feature_label)
fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=200)
fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=self.max_points)
# Apparently markers are not supported during conversion into html
for el in fig.data:
if el.type == "bar":
Expand All @@ -408,6 +422,37 @@ def display_model_explainability(self):
"plot": plotly.io.to_html(fig, include_plotlyjs=False, full_html=False),
}
)

# Interaction Plot
explain_contrib_data_interaction = list()
list_ind, _ = self.explainer.plot._select_indices_interactions_plot(
selection=None, max_points=self.max_points
)
interaction_values = self.explainer.get_interaction_values(selection=list_ind)
sorted_top_features_indices = compute_sorted_variables_interactions_list_indices(interaction_values)
indices_to_plot = sorted_top_features_indices[: self.nb_top_interactions]

for i, ids in enumerate(indices_to_plot):
id0, id1 = ids

fig_one_interaction = self.explainer.plot.interactions_plot(
col1=self.explainer.columns_dict[id0],
col2=self.explainer.columns_dict[id1],
max_points=self.max_points,
)

explain_contrib_data_interaction.append(
{
"feature_index": i,
"name": self.explainer.columns_dict[id0] + " / " + self.explainer.columns_dict[id1],
"description": self.explainer.features_dict[self.explainer.columns_dict[id0]]
+ " / "
+ self.explainer.features_dict[self.explainer.columns_dict[id1]],
"plot": plotly.io.to_html(fig_one_interaction, include_plotlyjs=False, full_html=False),
}
)

# Aggregating the data
explain_data.append(
{
"index": index_label,
Expand All @@ -416,6 +461,7 @@ def display_model_explainability(self):
fig_features_importance, include_plotlyjs=False, full_html=False
),
"features": explain_contrib_data,
"features_interaction": explain_contrib_data_interaction,
}
)
print_html(explainability_template.render(labels=explain_data))
Expand Down

0 comments on commit 62703a2

Please sign in to comment.