Time predictions from decision_tree() with the "rpart" engine correct? #331

Open
jesusherranz opened this issue Sep 5, 2024 · 2 comments
Labels
bug an unexpected problem or unintended behavior

Comments

jesusherranz commented Sep 5, 2024

When I run decision_tree() with the "rpart" engine, the concordance index reported during tuning and the one computed on a test sample are both below 0.5, when I would expect them to be greater than 0.5. I provide an example with a well-known survival dataset (Hosmer's WHAS500 data) from the "smoothHR" package.
I have tested other datasets and the results are similar. Also, if I repeat this script changing the engine to "partykit", I get a concordance index greater than 0.7, which looks correct.

library(tidymodels)
library(censored)
library(smoothHR)

## Data
whas500 <- whas500 %>% select(age, gender, hr, sysbp, diasbp, bmi, cvd, afb,
                              sho, chf, av3, miord, mitype, lenfol, fstat)

set.seed(252)
whas500_split <- initial_split(whas500, strata = fstat)
whas500_train <- training(whas500_split)
whas500_test <- testing(whas500_split)
                         
whas500_train <- whas500_train %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")
whas500_test <- whas500_test %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")

## resampling
set.seed(253)
cv_split <- vfold_cv(whas500_train, v = 10, repeats = 2 )

## Model specification
tree_spec <- 
    decision_tree( tree_depth = tune(), min_n = tune(),
                   cost_complexity = tune() ) %>%
    set_engine("rpart") %>% 
    set_mode("censored regression") 
              
## Workflow              
wflow_tree <- workflow() %>%
  add_model(tree_spec) %>% 
  add_formula(surv_var ~ . ) 
  
## Parameters Tune
tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(),
                          levels = 4 )    

tune_result_tree <- wflow_tree %>% 
  tune_grid( resamples = cv_split, grid = tree_grid, 
             metrics = metric_set(concordance_survival) ) 
show_best(tune_result_tree, metric="concordance_survival")

## Final workflow and final model
final_wflow_tree <- wflow_tree %>% 
  finalize_workflow( select_best(tune_result_tree, metric="concordance_survival") )
tree_fit <- final_wflow_tree %>% fit(whas500_train)
tree_fit

## Predictions in the testing sample
pred_tree_time <- predict(tree_fit, whas500_test, type = "time")
pred_tree_df <- bind_cols(whas500_test %>% select(surv_var), pred_tree_time ) 
head(pred_tree_df)

## Concordance
concordance_survival(pred_tree_df, truth = surv_var, estimate = .pred_time ) 
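A quick sanity check (a hedged sketch, reusing pred_tree_df from the script above): a concordance index well below 0.5 usually means the predictions are ordered in the wrong direction, e.g. a risk-like score (higher = worse) being read as a survival time (higher = better). Reversing the ordering of the predictions should then give roughly 1 - C.

## hedged sketch: reverse the ordering of the predicted values
pred_tree_df %>%
  mutate(.pred_reversed = max(.pred_time) - .pred_time) %>%  # flip the ranking, keep values non-negative
  concordance_survival(truth = surv_var, estimate = .pred_reversed)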


hfrick commented Sep 6, 2024

@jesusherranz Thank you for opening this issue and the reprex! I think this is not an issue with the concordance per se but rather with the predictions for this engine (so I'll change the title of the issue).

The concordance index is different because the predictions are so different. Those might be so different because the models are quite different. But even when the models are more similar, the predictions are very different. So I currently assume it might be a bug with how we generate the predictions.

library(tidymodels)
library(censored)
#> Loading required package: survival
library(smoothHR)
#> Loading required package: splines

whas500 <- whas500 %>% 
  select(age, gender, hr, sysbp, diasbp, bmi, cvd, afb,
   sho, chf, av3, miord, mitype, lenfol, fstat) %>%
  mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")

rpart_spec <- 
    decision_tree() %>%
    set_engine("rpart") %>% 
    set_mode("censored regression") 
rpart_fit <- rpart_spec %>% 
    fit(surv_var ~ ., data = whas500)
rpart_pred <- augment(rpart_fit, whas500, type = "time", eval_time = 100)

partykit_spec <- 
    decision_tree() %>%
    set_engine("partykit") %>% 
    set_mode("censored regression")
partykit_fit <- partykit_spec %>%
    fit(surv_var ~ ., data = whas500)
partykit_pred <- augment(partykit_fit, whas500, type = "time", eval_time = 100)

# comparisons
concordance_survival(rpart_pred, truth = surv_var, estimate = .pred_time ) 
#> # A tibble: 1 × 3
#>   .metric              .estimator .estimate
#>   <chr>                <chr>          <dbl>
#> 1 concordance_survival standard       0.198
concordance_survival(partykit_pred, truth = surv_var, estimate = .pred_time )
#> # A tibble: 1 × 3
#>   .metric              .estimator .estimate
#>   <chr>                <chr>          <dbl>
#> 1 concordance_survival standard       0.745

# the metric is different because the predictions are so different
rpart_pred %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#>    surv_var .pred_time
#>      <Surv>      <dbl>
#>  1    2178+     0.695 
#>  2    2172+     0.175 
#>  3    2190+     0.175 
#>  4     297      0.577 
#>  5    2131+     0.0296
#>  6       1      1.61  
#>  7    2122+     0.180 
#>  8    1496      0.175 
#>  9     920      3.51  
#> 10    2175+     0.175 
#> # ℹ 490 more rows
partykit_pred %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#>    surv_var .pred_time
#>      <Surv>      <dbl>
#>  1    2178+       1174
#>  2    2172+       2358
#>  3    2190+       2358
#>  4     297        2350
#>  5    2131+       2358
#>  6       1        2358
#>  7    2122+       2358
#>  8    1496        2358
#>  9     920         187
#> 10    2175+       2358
#> # ℹ 490 more rows

# the predictions might be so different because the models are quite different
rpart_fit %>% extract_fit_engine()
#> $rpart
#> n= 500 
#> 
#> node), split, n, deviance, yval
#>       * denotes terminal node
#> 
#>  1) root 500 695.561200 1.00000000  
#>    2) age< 71.5 244 229.506500 0.38854860  
#>      4) chf< 0.5 195 138.284600 0.22717860  
#>        8) hr< 69.5 62   1.940894 0.02955276 *
#>        9) hr>=69.5 133 120.814900 0.33623700  
#>         18) sysbp>=108 120  75.364020 0.24836930  
#>           36) sysbp< 169.5 93  40.381390 0.17482940 *
#>           37) sysbp>=169.5 27  29.875960 0.57382780  
#>             74) diasbp>=99.5 18   7.528452 0.18015450 *
#>             75) diasbp< 99.5 9   9.935203 1.82297300 *
#>         19) sysbp< 108 13  31.619030 1.60516100 *
#>      5) chf>=0.5 49  62.166150 1.11613100  
#>       10) miord< 0.5 30  30.182480 0.57699630 *
#>       11) miord>=0.5 19  17.010920 2.44126100 *
#>    3) age>=71.5 256 355.500100 1.80819600  
#>      6) age< 85.5 189 243.593000 1.41409700  
#>       12) chf< 0.5 120 151.770400 1.04087300  
#>         24) hr< 101 95 119.871800 0.87393850  
#>           48) bmi>=20.87589 79  92.194110 0.69519550 *
#>           49) bmi< 20.87589 16  18.493300 2.00204100 *
#>         25) hr>=101 25  25.614090 1.80146000 *
#>       13) chf>=0.5 69  74.993720 2.27279500 *
#>      7) age>=85.5 67  81.206960 3.50931200 *
#> 
#> $survfit
#> 
#> Call: prodlim::prodlim(formula = form, data = data)
#> Stratified Kaplan-Meier estimator for the conditional event time survival function
#> Discrete predictor variable: rpartFactor (0.0295527632560162, 0.174829429431531, 0.180154466046163, 0.576996279039068, 0.695195545979456, 1.6051614128208, 1.80145950646331, 1.82297331997917, 2.00204080181348, 2.27279546726114, 2.44126127657507, 3.50931210269065)
#> 
#> Right-censored response of a survival model
#> 
#> No.Observations: 500 
#> 
#> Pattern:
#>                 Freq
#>  event          215 
#>  right.censored 285 
#> 
#> $levels
#>  [1] "0.0295527632560162" "0.174829429431531"  "0.180154466046163" 
#>  [4] "0.576996279039068"  "0.695195545979456"  "1.6051614128208"   
#>  [7] "1.80145950646331"   "1.82297331997917"   "2.00204080181348"  
#> [10] "2.27279546726114"   "2.44126127657507"   "3.50931210269065"  
#> 
#> attr(,"class")
#> [1] "pecRpart"
partykit_fit %>% extract_fit_engine()
#> 
#> Model formula:
#> surv_var ~ age + gender + hr + sysbp + diasbp + bmi + cvd + afb + 
#>     sho + chf + av3 + miord + mitype
#> 
#> Fitted party:
#> [1] root
#> |   [2] age <= 72
#> |   |   [3] chf <= 0: 2358.000 (n = 204)
#> |   |   [4] chf > 0
#> |   |   |   [5] miord <= 0: 2350.000 (n = 33)
#> |   |   |   [6] miord > 0: 400.500 (n = 20)
#> |   [7] age > 72
#> |   |   [8] chf <= 0: 1174.000 (n = 141)
#> |   |   [9] chf > 0: 187.000 (n = 102)
#> 
#> Number of inner nodes:    4
#> Number of terminal nodes: 5

# refit the model with the same predictors (to have more similar models)
rpart_fit_age <- rpart_spec %>% 
    fit(surv_var ~ age + chf, data = whas500)
rpart_pred_age <- augment(rpart_fit_age, whas500, type = "time", eval_time = 100)
rpart_pred_age %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#>    surv_var .pred_time
#>      <Surv>      <dbl>
#>  1    2178+      1.04 
#>  2    2172+      0.227
#>  3    2190+      0.227
#>  4     297       1.12 
#>  5    2131+      0.227
#>  6       1       0.227
#>  7    2122+      0.227
#>  8    1496       0.227
#>  9     920       3.51 
#> 10    2175+      0.227
#> # ℹ 490 more rows

Created on 2024-09-06 with reprex v2.1.0
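One hedged way to probe this hypothesis (an aside, not part of the reprex above): the .pred_time values look suspiciously like the yval column of the printed rpart tree (e.g. 0.0296, 0.175, 3.51), i.e. the node-wise relative event rates rather than times. If that is what is being passed through, predicting directly from the underlying rpart object should reproduce them.

## hedged check: extract the raw rpart tree stored inside the pecRpart fit
rpart_engine <- extract_fit_engine(rpart_fit)$rpart
## predict.rpart() returns the node mean (a relative event rate for a survival tree)
head(predict(rpart_engine, newdata = whas500))
## if the hypothesis holds, these values match the .pred_time column above
## (0.695, 0.175, 0.175, ...)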

hfrick added the bug label Sep 6, 2024
hfrick changed the title from 'Concordance index with decision_tree() with the "rpart" engine' to 'Time predictions from decision_tree() with the "rpart" engine correct?' Sep 6, 2024

jesusherranz commented Sep 9, 2024

Thank you very much @hfrick for your quick response and for simplifying the problem. Clearly, it is a problem with the "time" predictions, which seem to be measured in other units in rpart; they must mean something other than "time".
I would like to ask a related question. Sometimes the time predictions from the "partykit" engine are Inf (I think this is the median of the survival curve), and I don't quite understand how a c-index can be calculated with such values.
Thanks in advance
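
On the Inf predictions, a hedged illustration (a toy example, not from this thread) of why a predicted median survival time can be infinite: if the Kaplan-Meier curve in a node never drops below 0.5, the median is never reached. The concordance index is usually still computable because it only uses the ranking of the predicted times: a finite prediction versus an Inf prediction can still be ordered, and pairs of Inf predictions are simply ties.

## hedged toy example: heavily censored data where the Kaplan-Meier curve
## never reaches 0.5, so the median survival time is undefined
library(survival)
toy <- data.frame(
  time   = c(5, 8, 10, 12, 15, 20),
  status = c(1, 0, 0, 0, 0, 0)   # only one event, the rest are censored
)
km <- survfit(Surv(time, status) ~ 1, data = toy)
summary(km)$table["median"]
## NA here; the Inf values mentioned above presumably correspond to this situation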
