# Local-diagnostics Plots {#localDiagnostics}
```{r, echo=FALSE, warning=FALSE}
source("code_snippets/ema_init.R")
```
## Introduction {#cPLocDiagIntro}
It may happen that, despite the fact that the predictive performance of a model is satisfactory overall, the model's predictions for some observations are drastically worse. In such a situation it is often said that "the model does not cover well some areas of the input space".\index{Model | diagnostics}
For example, a model fitted to the data for "typical" patients in a certain hospital may not perform well for patients from another hospital with possibly different characteristics. Or, a model developed to evaluate the risk of spring-holiday consumer-loans may not perform well in the case of autumn-loans for Christmas-holiday gifts.
For this reason, in case of important decisions, it is worthwhile to check how the model behaves locally for observations similar to the instance of interest.
In this chapter, we present two local-diagnostics techniques that address this issue. The first are *local-fidelity plots*, which evaluate the local predictive performance of the model around the observation of interest. The second are *local-stability plots*, which assess the (local) stability of predictions around the observation of interest. \index{Local-stability plot}
<!---
The general idea behind fidelity plots is to select a number of observations ("neighbours") from the validation dataset that are closest to the instance (observation) of interest. Then, for the selected observations, we plot CP profiles and check how stable they are. Additionally, if we know true values of the dependent variable for the selected neighbours, we may add residuals to the plot to evaluate the local fit of the model.
--->
## Intuition {#cPLocDiagIntuition}
Assume that, for the observation of interest, we have identified a set of observations from the training data with similar characteristics. We will call these similar observations "neighbours". The basic idea behind local-fidelity plots is to compare the distribution of residuals (i.e., differences between the observed and predicted value of the dependent variable; see equation \@ref(eq:modelResiduals)) for the neighbours with the distribution of residuals for the entire training dataset.
Figure \@ref(fig:profileBack2BackHist) presents histograms of residuals for the entire dataset and for a selected set of 25 neighbours for an instance of interest for the random forest model for the apartment-prices dataset (Section \@ref(model-Apartments-rf)). The distribution of residuals for the entire dataset is rather symmetric and centred around 0, suggesting a reasonable overall performance of the model. On the other hand, the residuals for the selected neighbours are centred around the value of 500. This suggests that, for the apartments similar to the one of interest, the model is biased towards values smaller than the observed ones (residuals are positive, so, on average, the observed value of the dependent variable is larger than the predicted value).
(ref:profileBack2BackHistDesc) Histograms of residuals for the random forest model `apartments_rf` for the apartment-prices dataset. Upper panel: residuals calculated for all observations from the dataset. Bottom panel: residuals calculated for 25 nearest neighbours of the instance of interest.
```{r profileBack2BackHist, echo=FALSE, fig.cap='(ref:profileBack2BackHistDesc)', out.width = '100%', fig.align='center'}
knitr::include_graphics("figure/bb_hist.png")
```
The idea behind local-stability plots is to check whether small changes in the explanatory variables, as represented by the changes within the set of neighbours, have much influence on the predictions. Figure \@ref(fig:profileWith10NN) presents CP profiles for variable *age* for an instance of interest and its 10 nearest neighbours for the random forest model for the Titanic dataset (Section \@ref(model-titanic-rf)). The profiles are almost parallel and very close to each other. In fact, some of them overlap so that only 5 different ones are visible. This suggests that the model's predictions are stable around the instance of interest. Of course, CP profiles for different explanatory variables may be very different, so a natural question is: which variables should we examine? The obvious choice is to focus on the variables that are the most important according to a variable-importance measure such as the ones discussed in Chapters \@ref(breakDown), \@ref(shapley), \@ref(LIME), or \@ref(ceterisParibusOscillations).
(ref:profileWith10NNDesc) Ceteris-paribus profiles for a selected instance (dark violet line) and 10 nearest neighbours (light grey lines) for the random forest model for the Titanic data. <!---The profiles are almost parallel and close to each other what suggests the stability of the model's predictions around the selected instance. In fact, some of them overlap so that only 5 different ones are visible. --->
```{r profileWith10NN, echo=FALSE, fig.cap='(ref:profileWith10NNDesc)', out.width = '60%', fig.align='center'}
knitr::include_graphics("figure/example_cp.png")
```
## Method {#cPLocDiagMethod}
To construct local-fidelity or local-stability plots, we first have to select "neighbours" of the observation of interest. Then, for the fidelity analysis, we have to calculate and compare residuals for the neighbours. For the stability analysis, we have to calculate and visualize CP profiles for the neighbours.
In what follows, we discuss each of these steps in more detail.
### Nearest neighbours {#cPLocDiagneighbours}
There are two important questions related to the selection of the neighbours "nearest" to the instance (observation) of interest:
* How many neighbours should we choose?
* What metric should be used to measure the "proximity" of observations?
The answer to both questions is *it depends*.
The smaller the number of neighbours, the more local is the analysis. However, selecting a very small number will lead to a larger variability of the results. In many cases we found that having about 20 neighbours works fine. However, one should always take into account computational time (because a smaller number of neighbours will result in faster calculations) and the size of the dataset (because, for a small dataset, a smaller set of neighbours may be preferred).
The choice of the metric is very important, and it matters more as the number of explanatory variables grows. In particular, the metric should be capable of accommodating variables of different nature (categorical, continuous). Our default choice is the Gower similarity measure:\index{Gower similarity measure}
\begin{equation}
d_{gower}(\underline{x}_i, \underline{x}_j) = \frac{1}{p} \sum_{k=1}^p d^k(x_i^k, x_j^k),
(\#eq:Gower)
\end{equation}
where $\underline{x}_i$ is a $p$-dimensional vector of values of explanatory variables for the $i$-th observation and $d^k(x_i^k,x_j^k)$ is the distance between the values of the $k$-th variable for the $i$-th and $j$-th observations. Note that $p$ may be equal to the number of all explanatory variables included in the model, or only a subset of them.
Metric $d^k()$ in \@ref(eq:Gower) depends on the nature of the variable. For a continuous variable, it is equal to
$$
d^k(x_i^k, x_j^k)=\frac{|x_i^k-x_j^k|}{\max(x_1^k,\ldots,x_n^k)-\min(x_1^k,\ldots,x_n^k)},
$$
i.e., the absolute difference scaled by the observed range of the variable. On the other hand, for a categorical variable,
$$
d^k(x_i^k, x_j^k)=1_{x_i^k \neq x_j^k},
$$
where $1_A$ is the indicator function for condition $A$.
An advantage of the Gower similarity measure is that it can be used for vectors with both categorical and continuous variables. A disadvantage is that it takes into account neither correlation between variables nor variable importance. For a high-dimensional setting, an interesting alternative is the proximity measure used in random forests [@randomForestBreiman], as it takes into account variable importance; however, it requires a fitted random forest model.
Once we have decided on the number of neighbours, we can use the chosen metric to select the required number of observations "closest" to the one of interest.
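The selection of neighbours can be illustrated with a minimal base-R sketch. It assumes the distance convention in which a continuous variable contributes the absolute difference scaled by its observed range, and a categorical variable contributes 0 when the two values agree and 1 otherwise. The function `gower_distance()` and the toy data frame `toy` are hypothetical names introduced only for this illustration; they are not part of any package used in this book.

```r
# Gower-type distance between a single observation `x` (one-row data frame)
# and every row of `data` (data frame with the same columns).
# Hypothetical helper written for illustration only.
gower_distance <- function(x, data) {
  stopifnot(identical(names(x), names(data)))
  per_variable <- lapply(names(data), function(v) {
    if (is.numeric(data[[v]])) {
      # continuous variable: absolute difference scaled by the observed range
      rng <- diff(range(data[[v]]))
      if (rng == 0) rep(0, nrow(data)) else abs(data[[v]] - x[[v]]) / rng
    } else {
      # categorical variable: 0 if the values agree, 1 otherwise
      as.numeric(as.character(data[[v]]) != as.character(x[[v]]))
    }
  })
  # average of the per-variable distances over the p variables
  Reduce(`+`, per_variable) / length(per_variable)
}

# Toy data (made up for this example)
toy <- data.frame(
  age   = c(20, 40, 60, 30),
  class = factor(c("1st", "2nd", "1st", "3rd"))
)
target <- data.frame(age = 30,
                     class = factor("1st", levels = levels(toy$class)))

d <- gower_distance(target, toy)
d
# indices of the 2 observations closest to `target`
neighbours_idx <- head(order(d), 2)
neighbours_idx
```

For the toy data, the first and third observations are the closest to the target: they share its class, so only the scaled difference in age contributes to their distance.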
### Local-fidelity plot {#cPLocDiagLFplot}
Figure \@ref(fig:profileBack2BackHist) summarizes two distributions of residuals, i.e., residuals for the neighbours of the observation of interest and residuals for the entire training dataset except for neighbours.
For a typical observation, these two distributions should be similar. An alarming situation is if the residuals for the neighbours are shifted towards positive or negative values.
Apart from visual examination, we may use statistical tests to compare the two distributions. If we do not want to assume any particular parametric form of the distributions (like, e.g., normal), we may choose non-parametric tests like the Wilcoxon test or the Kolmogorov-Smirnov test. For statistical tests, it is important that the two sets are disjoint.\index{Kolmogorov-Smirnov test}\index{Wilcoxon test}
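The comparison of the two residual distributions can be sketched with the base-R functions `ks.test()` and `wilcox.test()`. The residuals below are simulated purely for illustration (they do not come from any model discussed in this book), and the neighbours' residuals are deliberately shifted upwards. The names `residuals_all` and `neighbour_idx` are assumptions made for this example.

```r
set.seed(1)

# Simulated residuals for 500 observations (illustration only)
residuals_all <- rnorm(500, mean = 0, sd = 1)

# Assume that observations 1..25 were selected as the neighbours and that
# the model under-predicts for them, so their residuals are shifted upwards
neighbour_idx <- 1:25
residuals_all[neighbour_idx] <- residuals_all[neighbour_idx] + 1.5

# The two compared sets are disjoint: neighbours vs. all remaining observations
res_neighbours <- residuals_all[neighbour_idx]
res_others     <- residuals_all[-neighbour_idx]

# Non-parametric two-sample tests
ks  <- ks.test(res_neighbours, res_others)
wil <- wilcox.test(res_neighbours, res_others)

c(kolmogorov_smirnov = ks$p.value, wilcoxon = wil$p.value)
```

Small p-values indicate that the residuals for the neighbours are distributed differently from the residuals for the remaining observations. They do not, on their own, indicate the direction of the shift, so the histograms should still be inspected.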
### Local-stability plot {#cPLocDiagProfiles}
Once neighbours of the observation of interest have been identified, we can graphically compare CP profiles for selected (or all) explanatory variables.
For a model with a large number of variables, we may end up with a large number of plots. In such a case, a better strategy is to focus only on a few most important variables, selected by using a variable-importance measure (see, for example, Chapter \@ref(ceterisParibusOscillations)).
CP profiles are helpful to assess model stability.\index{Model ! stability} In addition, we can enhance the plot by adding residuals, which allows evaluation of the local model-fit. The plot that includes CP profiles for the nearest neighbours and the corresponding residuals is called a local-stability plot.
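The construction of such a plot can be sketched in base R. The example below is only an illustration under simplifying assumptions: it uses a linear model fitted to the built-in `mtcars` data rather than any model from this book, selects neighbours with a range-scaled distance on two continuous variables, and relies on a hypothetical helper `cp_profile()`.

```r
# Simple additive model on a built-in dataset (illustration only)
model <- lm(mpg ~ wt + hp, data = mtcars)

# CP profile: vary one variable over a grid, keep all others fixed
# (hypothetical helper, not a function from any package)
cp_profile <- function(model, observation, variable, grid) {
  newdata <- observation[rep(1, length(grid)), , drop = FALSE]
  newdata[[variable]] <- grid
  as.numeric(predict(model, newdata = newdata))
}

instance   <- mtcars["Mazda RX4", ]
candidates <- mtcars[rownames(mtcars) != "Mazda RX4", ]

# Range-scaled distance on the two (continuous) explanatory variables
scaled_diff <- function(v) {
  abs(candidates[[v]] - instance[[v]]) / diff(range(mtcars[[v]]))
}
distance   <- (scaled_diff("wt") + scaled_diff("hp")) / 2
neighbours <- candidates[order(distance)[1:5], ]

# CP profiles for variable `wt`: one column per neighbour
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 50)
profiles <- sapply(seq_len(nrow(neighbours)), function(i) {
  cp_profile(model, neighbours[i, ], "wt", grid)
})
instance_profile <- cp_profile(model, instance, "wt", grid)

# Residuals (observed minus predicted) for the neighbours
neighbour_pred      <- as.numeric(predict(model, newdata = neighbours))
neighbour_residuals <- neighbours$mpg - neighbour_pred

# Profiles of the neighbours (grey), profile of the instance (black),
# and residuals as vertical intervals (blue: positive, red: negative)
matplot(grid, profiles, type = "l", lty = 1, col = "grey60",
        ylim = range(profiles, instance_profile, neighbours$mpg),
        xlab = "wt", ylab = "predicted mpg")
lines(grid, instance_profile, lwd = 2)
segments(x0 = neighbours$wt, y0 = neighbour_pred, y1 = neighbours$mpg,
         col = ifelse(neighbour_residuals >= 0, "blue", "red"))
```

Because the model in this sketch is additive, the profiles are exactly parallel. For a more flexible model, such as a random forest, the profiles may cross or diverge.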
## Example: Titanic {#cPLocDiagExample}
As an example, we will consider the prediction for Johnny D (see Section \@ref(predictions-titanic)) for the random forest model for the Titanic data (see Section \@ref(model-titanic-rf)).
Figure \@ref(fig:localStabilityPlotAge) presents a detailed explanation of the elements of a local-stability plot for *age*, a continuous explanatory variable. The plot includes eight nearest neighbours of Johnny D (see Section \@ref(predictions-titanic)). The green line shows the CP profile for the instance of interest. Profiles of the nearest neighbours are marked with grey lines. The vertical intervals correspond to residuals; the shorter the interval, the smaller the residual and the more accurate the model's prediction. Blue intervals correspond to positive residuals, red intervals to negative residuals. For an additive model, CP profiles will be approximately parallel. For a model with stable predictions, the profiles should be close to each other. This is not the case in Figure \@ref(fig:localStabilityPlotAge), in which the profiles are quite far apart. Thus, the plot suggests potential instability of the model's predictions. Note that there are positive and negative residuals included in the plot. This indicates that, on average, the instance prediction itself should not be biased.
(ref:localStabilityPlotAgeDesc) Elements of a local-stability plot for a continuous explanatory variable. Ceteris-paribus profiles for variable *age* for Johnny D and 5 nearest neighbours for the random forest model for the Titanic data.
```{r localStabilityPlotAge, echo=FALSE, fig.cap='(ref:localStabilityPlotAgeDesc)', out.width = '70%', fig.align='center'}
knitr::include_graphics("figure/localFidelityPlots.png")
```
<!--
Figure \@ref(fig:localStabilityPlotClass) presents a local-stability plot [TOMASZ: THE HEADER SAYS LOCAL-FIDELITY.] for the categorical explanatory variable *class* for Henry and 11 nearest neighbours (identified by appropriate indices on the y-axis). Henry and his neighbours travelled in the first class. The panels in the plot show how the predicted probability of survival would change if the travel class changed. Dots indicate the original model's predictions for the neighbours, while the end of the interval corresponds to the model's prediction after changing the class. Colour of the interval indicates whether the change of the class increases or decreases the prediction. The top-left panel indicates that, for the majority of the neighbours, the change from the first to the second class reduces the predicted value of the probability of survival. On the other hand, the top-right panel indicates that changing the class from "1st" to "deck crew" increases the predicted probability.
Local-stability plots similar to the one shown in Figure \@ref(fig:localStabilityPlotClass) can help to detect interactions if we see that the same change (say, from the first to the third class) results in a different change of the model's prediction for different neighbours.
(ref:localStabilityPlotClassDesc) The local-stability plot for the categorical explanatory variable *class* for Henry and 11 nearest neighbours for the `titanic_rf` random forest model for the Titanic data.[TOMASZ: WHY 11? FOR AGE, 10 WERE USED. WHY THE PLOT HAS LOCAL-FIDELITY IN THE HEADER? IT DOES NOT SHOW RESIDUALS. STATIC FIGURE, CHECK IF CORRESPONNDS TO THE ACTUAL MODEL.]
```{r localStabilityPlotClass, echo=FALSE, fig.cap='(ref:localStabilityPlotClassDesc)', out.width = '70%', fig.align='center'}
knitr::include_graphics("figure/cp_fidelity_2.png")
```
-->
## Pros and cons {#cPLocDiagProsCons}
Local-stability plots may be very helpful to check if the model is locally additive, as for such models the CP profiles should be parallel. Also, the plots allow assessing whether the model is locally stable, as in that case the CP profiles should be close to each other. Local-fidelity plots are useful in checking whether the model-fit for the instance of interest is unbiased, as in that case the residuals should be small and their distribution should be symmetric around 0.
The disadvantage of both types of plots is that they are quite complex and lack objective measures of the quality of the model-fit. Thus, they are mainly suitable for exploratory analysis.
## Code snippets for R {#cPLocDiagR}
In this section, we present local diagnostic plots as implemented in the `DALEX` package for R. \index{Diagnostic plot}\index{Local-diagnostic plot}
For illustration, we use the random forest model `titanic_rf` (Section \@ref(model-titanic-rf)). The model was developed to predict the probability of survival after the sinking of the Titanic. Instance-level explanations are calculated for Henry, a 47-year-old male passenger who travelled in the first class (see Section \@ref(predictions-titanic)).
We first retrieve the `titanic_rf` model-object and the data frame for Henry via the `archivist` hooks, as listed in Section \@ref(ListOfModelsTitanic). We also retrieve the version of the `titanic` data with imputed missing values.
```{r, warning=FALSE, message=FALSE, eval=FALSE}
titanic_imputed <- archivist::aread("pbiecek/models/27e5c")
titanic_rf      <- archivist::aread("pbiecek/models/4e0fc")
henry           <- archivist::aread("pbiecek/models/a6538")
```
```{r, echo=FALSE}
henry
```
Then we construct the explainer for the model by using function `explain()` from the `DALEX` package (see Section \@ref(ExplainersTitanicRCode)). We also load the `randomForest` package, as the model was fitted by using function `randomForest()` from this package (see Section \@ref(model-titanic-rf)) and it is important to have the corresponding `predict()` function available. The model's prediction for Henry is obtained with the help of that function.
```{r, warning=FALSE, message=FALSE, echo = TRUE, eval = FALSE}
library("randomForest")
library("DALEX")
explain_rf <- DALEX::explain(model = titanic_rf,
                             data = titanic_imputed[, -9],
                             y = titanic_imputed$survived == "yes",
                             label = "Random Forest")
predict(explain_rf, henry)
```
```{r, warning=FALSE, message=FALSE, echo = FALSE, eval = TRUE}
library("randomForest")
library("DALEX")
explain_rf <- DALEX::explain(model = titanic_rf,
                             data = titanic_imputed[, -9],
                             y = titanic_imputed$survived == "yes",
                             label = "Random Forest",
                             verbose = FALSE)
predict(explain_rf, henry)
```
To construct a local-fidelity plot similar to the one shown in Figure \@ref(fig:profileBack2BackHist), we can use the `predict_diagnostics()` function from the `DALEX` package. The main arguments of the function are `explainer`, which specifies the name of the explainer-object for the model to be explained, and `new_observation`, which specifies the name of the data frame for the instance for which prediction is of interest. Additional useful arguments are `neighbours`, which specifies the number of observations similar to the instance of interest to be selected (default is 50), and `distance`, the function used to measure the similarity of the observations (by default, the Gower similarity measure is used). Note that function `predict_diagnostics()` has to compute residuals. Thus, we have to specify the `y` and `residual_function` arguments when using function `explain()` to create the explainer-object (see Section \@ref(ExplainersTitanicRCode)). If the `residual_function` argument is applied with the default `NULL` value, then model residuals are calculated as in \@ref(eq:modelResiduals).
In the code below, we perform computations for the random forest model `titanic_rf` and Henry. We select 100 "neighbours" of Henry by using the (default) Gower similarity measure.
```{r residualDistributionFidelityPlot, warning=FALSE, message=FALSE, eval=TRUE, fig.width=7, fig.height=4, out.width = '70%', fig.align='center'}
id_rf <- predict_diagnostics(explainer = explain_rf,
                             new_observation = henry,
                             neighbours = 100)
id_rf
```
The resulting object is of class `predict_diagnostics`. It is a list of several components that includes, among others, histograms summarizing the distribution of residuals for the entire training dataset and for the neighbours, as well as the result of the Kolmogorov-Smirnov test comparing the two distributions. The test result is given by default when the object is printed out. In our case, it suggests a statistically significant difference between the two distributions. We can use the `plot()` function to compare the distributions graphically. The resulting graph is shown in Figure \@ref(fig:localFidelityPlotResHenry). The plot suggests that the distribution of the residuals for Henry's neighbours might be slightly shifted towards positive values, as compared to the overall distribution.
(ref:localFidelityPlotResHenryDesc) The local-fidelity plot for the random forest model for the Titanic data and passenger Henry with 100 neighbours.
```{r, warning=FALSE, message=FALSE, eval=FALSE}
plot(id_rf)
```
```{r localFidelityPlotResHenry, warning=FALSE, message=FALSE, echo=FALSE, fig.width=9.5, fig.height=4, fig.cap='(ref:localFidelityPlotResHenryDesc)', out.width = '100%', fig.align='center'}
plot(id_rf) + theme_ema
```
Function `predict_diagnostics()` can also be used to construct local-stability plots. To this end, we have to select the explanatory variable for which we want to create the plot. We can do it by passing the name of the variable to the `variables` argument of the function. In the code below, we first calculate CP profiles and residuals for *age* and 10 neighbours of Henry.
<!---
Toward this aim, we use the `y` argument in the `individual_profile()` function. The argument takes numerical values. Our binary dependent variable `survived` assumes values `yes/no`; to convert them to numerical values, we use the `survived == "yes"` expression.
---->
```{r localStabilityAgeHenry, warning=FALSE, message=FALSE}
id_rf_age <- predict_diagnostics(explainer = explain_rf,
                                 new_observation = henry,
                                 neighbours = 10,
                                 variables = "age")
```
By applying the `plot()` function to the resulting object, we obtain the local-stability plot shown in Figure \@ref(fig:localStabilityPlotAgeHenry). The profiles are relatively close to each other, suggesting the stability of predictions. There are more negative than positive residuals, which may be seen as a signal of a (local) positive bias of the predictions.
(ref:localStabilityPlotAgeHenryDesc) The local-stability plot for variable *age* in the random forest model for the Titanic data and passenger Henry with 10 neighbours. Note that some profiles overlap, so the graph shows fewer lines.
```{r, warning=FALSE, message=FALSE, eval=FALSE}
plot(id_rf_age)
```
```{r localStabilityPlotAgeHenry, warning=FALSE, message=FALSE, echo=FALSE, fig.width=6, fig.height=4, fig.cap='(ref:localStabilityPlotAgeHenryDesc)', out.width = '60%', fig.align='center'}
plot(id_rf_age) + theme_ema
```
In the code below, we conduct the necessary calculations for the categorical variable *class* and 10 neighbours of Henry.
```{r localStabilityClassHenry, warning=FALSE, message=FALSE}
id_rf_class <- predict_diagnostics(explainer = explain_rf,
                                   new_observation = henry,
                                   neighbours = 10,
                                   variables = "class")
```
By applying the `plot()` function to the resulting object, we obtain the local-stability plot shown in Figure \@ref(fig:localStabilityPlotClassHenry). The profiles are not parallel, indicating non-additivity of the effect. However, they are relatively close to each other, suggesting the stability of predictions.
(ref:localStabilityPlotClassHenryDesc) The local-stability plot for variable *class* in the random forest model for the Titanic data and passenger Henry with 10 neighbours.
```{r, warning=FALSE, message=FALSE, eval=FALSE}
plot(id_rf_class)
```
```{r localStabilityPlotClassHenry, echo=FALSE, warning=FALSE, message=FALSE, eval=FALSE, fig.width=6, fig.height=4, fig.cap='(ref:localStabilityPlotClassHenryDesc)', out.width = '60%', fig.align='center'}
plot(id_rf_class) + theme_ema
```
## Code snippets for Python {#cPLocDiagPython}
At this point we are not aware of any Python libraries that would implement the methods presented in the current chapter.