Skip to content

Commit

Permalink
Merge pull request #64 from MAIF/feature/interaction_plot
Browse files Browse the repository at this point in the history
Add interaction plot in the report
  • Loading branch information
guillaume-vignal authored Sep 20, 2024
2 parents 3a16f0f + 26131ff commit 64c30a7
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
max-parallel: 1
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ sd.generate_report(

## 🛠 Installation

Eurybia is intended to work with Python versions 3.9 to 3.11. Installation can be done with pip:
Eurybia is intended to work with Python versions 3.9 to 3.12. Installation can be done with pip:

```
pip install eurybia
Expand Down
9 changes: 5 additions & 4 deletions eurybia/core/smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import shutil
import tempfile
from pathlib import Path
from typing import Dict

import catboost
import pandas as pd
Expand Down Expand Up @@ -199,12 +198,12 @@ def __init__(
def compile(
self,
full_validation=False,
ignore_cols: list = list(),
ignore_cols: list = None,
sampling=True,
sample_size=100000,
datadrift_file=None,
date_compile_auc=None,
hyperparameter: Dict = catboost_hyperparameter_init.copy(),
hyperparameter: dict = catboost_hyperparameter_init.copy(),
attr_importance="feature_importances_",
):
r"""
Expand Down Expand Up @@ -237,6 +236,8 @@ def compile(
>>> SD.compile()
"""
if ignore_cols is None:
ignore_cols = []
if datadrift_file is not None:
self.datadrift_file = datadrift_file
if hyperparameter is not None:
Expand Down Expand Up @@ -468,7 +469,7 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()
and will not be analyzed: \n {err_dtypes}"""
)
# Feature values
err_mods: Dict[str, Dict] = {}
err_mods: dict[str, dict] = {}
if full_validation is True:
invalid_cols = ignore_cols + new_cols + removed_cols + err_dtypes
for column in self.df_baseline.columns:
Expand Down
2 changes: 1 addition & 1 deletion eurybia/core/smartplotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def generate_modeldrift_data(
if data_modeldrift is None:
data_modeldrift = self.smartdrift.data_modeldrift
if data_modeldrift is None:
raise Exception(
raise ValueError(
"""You should run the add_data_modeldrift method before displaying model drift performances.
For more information see the documentation"""
)
Expand Down
16 changes: 14 additions & 2 deletions eurybia/report/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
pn.pane.Markdown("### Univariate analysis"),
pn.pane.Markdown(report_text["Data drift"]["07"]),
]
contribution_figures, contribution_labels = dr.display_model_contribution()

distribution_figures, labels, distribution_tables = dr.display_dataset_analysis(global_analysis=False)["univariate"]
distribution_plots_blocks = get_select_plots(
Expand Down Expand Up @@ -262,6 +261,9 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
max_gauge=0.2,
)
blocks += [pn.pane.Plotly(js_fig)]

contribution_figures, contribution_labels = dr.display_model_contribution()

blocks += [
pn.pane.Markdown("## Feature contribution on data drift's detection"),
pn.pane.Markdown(report_text["Data drift"]["09"]),
Expand All @@ -273,14 +275,24 @@ def get_data_drift_panel(dr: DriftReport) -> pn.Column:
figures=contribution_figures,
)
blocks += contribution_plots_blocks

fig_02 = dr.explainer.plot.top_interactions_plot(nb_top_interactions=10)
fig_02.update_layout(width=1240)
blocks += [
pn.pane.Markdown("## Feature interaction on data drift's detection"),
pn.pane.Markdown(report_text["Data drift"]["10"]),
pn.pane.Plotly(fig_02),
]

if dr.smartdrift.historical_auc is not None:
fig = dr.smartdrift.plot.generate_historical_datadrift_metric()
fig.update_layout(width=1240)
blocks += [
pn.pane.Markdown("## Historical Data drift"),
pn.pane.Markdown(report_text["Data drift"]["10"]),
pn.pane.Markdown(report_text["Data drift"]["11"]),
pn.pane.Plotly(fig),
]

return pn.Column(*blocks, name="Data drift", styles=dict(display="none"), css_classes=["data-drift"])


Expand Down
12 changes: 6 additions & 6 deletions eurybia/report/project_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import logging
import os
from typing import Dict, Optional, Union
from typing import Optional, Union

import jinja2
import pandas as pd
Expand Down Expand Up @@ -36,11 +36,11 @@ class DriftReport:
Attributes
----------
smartdrift: object
SmartDrift object
SmartDrift object
explainer : shapash.explainer.smart_explainer.SmartExplainer
A shapash SmartExplainer object that has already be compiled
A shapash SmartExplainer object that has already be compiled
title_story : str
Report title
Report title
metadata : dict
Information about the project (author, description, ...)
df_predict : pd.DataFrame
Expand All @@ -56,7 +56,7 @@ def __init__(
smartdrift: SmartDrift,
explainer: SmartExplainer,
project_info_file: Optional[str] = None,
config_report: Optional[Dict] = None,
config_report: Optional[dict] = None,
):
"""
Parameters
Expand Down Expand Up @@ -253,7 +253,7 @@ def display_model_contribution(self):
c_list = self.explainer._classes if multiclass else [1] # list just used for multiclass
plot_list = []
labels = []
for index_label, label in enumerate(c_list): # Iterating over all labels in multiclass case
for label in c_list: # Iterating over all labels in multiclass case
for feature in self.features_imp_list:
fig = self.explainer.plot.contribution_plot(feature, label=label, max_points=200)
plot_list.append(fig)
Expand Down
11 changes: 8 additions & 3 deletions eurybia/report/properties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict
from typing import Any

report_text: Dict[str, Any] = {
report_text: dict[str, Any] = {
"Index": {
"01": "- Project information: report context and information",
"02": "- Consistency Analysis: highlighting differences between the two datasets",
Expand Down Expand Up @@ -77,7 +77,12 @@
"This representation constitutes a support to understand the drift "
"when the analysis of the dataset is unclear."
),
"10": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."),
"10": (
"This graph represents the interactions between couple of variable to the data drift detection. "
"This representation constitutes a support to understand the drift "
"when the analysis of the dataset is unclear."
),
"11": ("Line chart showing the metrics evolution of the datadrift classifier over the given period of time."),
},
"Model drift": {
"01": (
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
Expand Down

0 comments on commit 64c30a7

Please sign in to comment.