-
Notifications
You must be signed in to change notification settings - Fork 309
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #196 from AutoFairAthenaRC/master
Adding Global Actions in A Nutshell For Counterfactual Explainability (GLANCE) framework
- Loading branch information
Showing
35 changed files
with
5,986 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from abc import ABC, abstractmethod | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
class ClusteringMethod(ABC): | ||
""" | ||
Abstract base class for clustering methods. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the ClusteringMethod. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def fit(self, data: pd.DataFrame): | ||
""" | ||
Fit the clustering model on the given data. | ||
Parameters: | ||
- data (pd.DataFrame): DataFrame of input data to fit the model. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def predict(self, instances: pd.DataFrame) -> np.ndarray: | ||
""" | ||
Predict the cluster labels for the given instances. | ||
Parameters: | ||
- instances (pd.DataFrame): DataFrame of input instances. | ||
Returns: | ||
- cluster_labels (np.ndarray): Array of cluster labels for each instance. | ||
""" | ||
pass | ||
|
||
|
||
class LocalCounterfactualMethod(ABC): | ||
""" | ||
Abstract base class for local counterfactual methods. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the LocalCounterfactualMethod. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def fit(self, **kwargs): | ||
""" | ||
Fit the counterfactual method. | ||
Parameters: | ||
- **kwargs: Additional keyword arguments for fitting. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def explain_instances( | ||
self, instances: pd.DataFrame, num_counterfactuals: int | ||
) -> pd.DataFrame: | ||
""" | ||
Find the local counterfactuals for the given instances. | ||
Parameters: | ||
- instances (pd.DataFrame): DataFrame of input instances for which counterfactuals are desired. | ||
- num_counterfactuals (int): Number of counterfactuals to generate for each instance. | ||
Returns: | ||
- counterfactuals (pd.DataFrame): DataFrame of counterfactual instances. | ||
""" | ||
pass | ||
|
||
|
||
class GlobalCounterfactualMethod(ABC): | ||
""" | ||
Abstract base class for global counterfactual methods. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
""" | ||
Initialize the LocalCounterfactualMethod. | ||
Parameters: | ||
- **kwargs: Additional keyword arguments for init. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def fit(self, X, y, **kwargs): | ||
""" | ||
Fit the counterfactual method. | ||
Parameters: | ||
- **kwargs: Additional keyword arguments for fitting. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def explain_group(self, instances: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Find the global counterfactuals for the given group of instances. | ||
Parameters: | ||
- instances (pd.DataFrame, optional): DataFrame of input instances for which global counterfactuals are desired. | ||
If None, explain the whole group of affected instances. | ||
Returns: | ||
- counterfactuals (pd.DataFrame): DataFrame of counterfactual instances. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .kmeans import KMeansMethod |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from ..base import ClusteringMethod | ||
from sklearn.cluster import KMeans | ||
|
||
|
||
class KMeansMethod(ClusteringMethod): | ||
""" | ||
Implementation of a clustering method using KMeans. | ||
This class provides an interface to apply KMeans clustering to a dataset. | ||
""" | ||
|
||
def __init__(self, num_clusters, random_seed): | ||
""" | ||
Initializes the KMeansMethod class. | ||
Parameters: | ||
---------- | ||
num_clusters : int | ||
The number of clusters to form as well as the number of centroids to generate. | ||
random_seed : int | ||
A seed for the random number generator to ensure reproducibility. | ||
""" | ||
|
||
self.num_clusters = num_clusters | ||
self.random_seed = random_seed | ||
self.model = KMeans() | ||
|
||
def fit(self, data): | ||
""" | ||
Fits the KMeans model on the provided dataset. | ||
Parameters: | ||
---------- | ||
data : array-like or sparse matrix, shape (n_samples, n_features) | ||
Training instances to cluster. | ||
Returns: | ||
------- | ||
None | ||
""" | ||
self.model = KMeans( | ||
n_clusters=self.num_clusters, n_init=10, random_state=self.random_seed | ||
) | ||
self.model.fit(data) | ||
|
||
def predict(self, instances): | ||
""" | ||
Predicts the nearest cluster each sample in the provided data belongs to. | ||
Parameters: | ||
---------- | ||
instances : array-like or sparse matrix, shape (n_samples, n_features) | ||
New data to predict. | ||
Returns: | ||
------- | ||
labels : array, shape (n_samples,) | ||
Index of the cluster each sample belongs to. | ||
""" | ||
return self.model.predict(instances) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Callable, List, Dict | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def build_dist_func_dataframe( | ||
X: pd.DataFrame, | ||
numerical_columns: List[str], | ||
categorical_columns: List[str], | ||
n_bins: int = 10, | ||
) -> Callable[[pd.DataFrame, pd.DataFrame], pd.Series]: | ||
""" | ||
Builds and returns a custom distance function for computing distances between rows of two DataFrames based on specified numerical and categorical columns. | ||
For numerical columns, the values are first binned into intervals based on the provided number of bins (`n_bins`). | ||
The distance between numerical features is computed as the sum of the absolute differences between binned values. For categorical columns, the distance is calculated as the number of mismatched categorical values. | ||
Parameters: | ||
---------- | ||
X : pd.DataFrame | ||
The reference DataFrame used to determine the bin intervals for numerical columns. | ||
numerical_columns : List[str] | ||
List of column names in `X` that contain numerical features. | ||
categorical_columns : List[str] | ||
List of column names in `X` that contain categorical features. | ||
n_bins : int, optional | ||
The number of bins to use when normalizing numerical columns, by default 10. | ||
Returns: | ||
------- | ||
Callable[[pd.DataFrame, pd.DataFrame], pd.Series] | ||
A distance function that takes two DataFrames as input (`X1` and `X2`) and returns a Series of distances between corresponding rows in `X1` and `X2`. | ||
The distance function works as follows: | ||
- For numerical columns: the absolute differences between binned values are summed. | ||
- For categorical columns: the number of mismatches between values is counted. | ||
""" | ||
feat_intervals = { | ||
col: ((max(X[col]) - min(X[col])) / n_bins) for col in numerical_columns | ||
} | ||
|
||
def bin_numericals(instances: pd.DataFrame): | ||
ret = instances.copy() | ||
for col in numerical_columns: | ||
ret[col] /= feat_intervals[col] | ||
return ret | ||
|
||
def dist_f(X1: pd.DataFrame, X2: pd.DataFrame) -> pd.Series: | ||
X1 = bin_numericals(X1) | ||
X2 = bin_numericals(X2) | ||
|
||
ret = (X1[numerical_columns] - X2[numerical_columns]).abs().sum(axis="columns") | ||
ret += (X1[categorical_columns] != X2[categorical_columns]).astype(int).sum(axis="columns") | ||
|
||
return ret | ||
|
||
return dist_f | ||
|
Empty file.
Oops, something went wrong.