Skip to content

Commit

Permalink
Merge pull request #597 from guillaume-vignal/feature/fix_bug_2.7.1
Browse files Browse the repository at this point in the history
fix bugs
  • Loading branch information
guillaume-vignal authored Oct 17, 2024
2 parents 674ead7 + c8fd7bd commit 523156e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ where = ["."]


[tool.setuptools.package-data]
"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png"]
"*" = ["*.csv", "*json", "*.yml", "*.css", "*.js", "*.png", "*.ico"]

[tool.pytest.ini_options]
pythonpath = ["."]
Expand Down
6 changes: 4 additions & 2 deletions shapash/plots/plot_bar_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ def plot_bar_chart(
bars = []
for num, expl in enumerate(zip(var_dict, x_val, contrib)):
feat_name, x_val_el, contrib_value = expl
group_name = inv_features_dict.get(feat_name)
is_grouped = False
if x_val_el == "":
ylabel = f"<i>{feat_name}</i>"
hoverlabel = f"<b>{feat_name}</b>"
else:
# If bar is a group of features, hovertext includes the values of the features of the group
# And color changes
group_name = inv_features_dict.get(feat_name)
if features_groups is not None and group_name in features_groups.keys() and len(index_value) > 0:
is_grouped = True
feat_groups_values = x_init[features_groups[group_name]].loc[index_value[0]]
hoverlabel = "<br />".join(
[
Expand Down Expand Up @@ -146,7 +148,7 @@ def plot_bar_chart(
color = -1 if x_val_el != "" else -2

# If the bar is a group of features we modify the color
if group_name is not None:
if is_grouped:
bar_color = style_dict["featureimp_groups"][0] if color == 1 else style_dict["featureimp_groups"][1]
else:
bar_color = dict_local_plot_colors[color]["color"]
Expand Down
1 change: 0 additions & 1 deletion shapash/plots/plot_scatter_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def _prediction_classification_plot(
df_pred = pd.concat(
[y_proba_values.reset_index(), y_pred.reset_index(drop=True), target.reset_index(drop=True)], axis=1
)
print(df_pred)
df_pred.set_index(df_pred.columns[0], inplace=True)
df_pred.columns = ["proba_values", "predict_class", "target"]
df_pred["wrong_predict"] = 1
Expand Down
8 changes: 6 additions & 2 deletions shapash/webapp/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,12 @@ def determine_total_pages_and_display(
Tuple[int, str, int]: Total pages, display properties, and updated page number.
"""
display_groups = explainer.features_groups is not None and bool_group
nb_features = len(explainer.features_imp_groups) if display_groups else len(explainer.features_imp)
total_pages = nb_features // features + 1
if explainer._case == "classification":
nb_features = len(explainer.features_imp_groups[0]) if display_groups else len(explainer.features_imp[0])
elif explainer._case == "regression":
nb_features = len(explainer.features_imp_groups) if display_groups else len(explainer.features_imp)

total_pages = (nb_features - 1) // features + 1
if (total_pages == 1) or (group_name):
display_page = {"display": "none"}
else:
Expand Down

0 comments on commit 523156e

Please sign in to comment.