Time predictions from decision_tree() with the "rpart" engine correct? #331
@jesusherranz Thank you for opening this issue and for the reprex! I think this is not an issue with the concordance per se but rather with the predictions for this engine (so I'll change the title of the issue). The concordance index is different because the predictions are so different. Those might be so different because the models are quite different. But even when the models are made more similar, the predictions are still very different, so I currently assume it might be a bug in how we generate the predictions.

library(tidymodels)
library(censored)
#> Loading required package: survival
library(smoothHR)
#> Loading required package: splines
whas500 <- whas500 %>%
select(age, gender, hr, sysbp, diasbp, bmi, cvd, afb,
sho, chf, av3, miord, mitype, lenfol, fstat) %>%
mutate(surv_var = Surv(lenfol, fstat), .keep = "unused")
rpart_spec <-
decision_tree() %>%
set_engine("rpart") %>%
set_mode("censored regression")
rpart_fit <- rpart_spec %>%
fit(surv_var ~ ., data = whas500)
rpart_pred <- augment(rpart_fit, whas500, type = "time", eval_time = 100)
partykit_spec <-
decision_tree() %>%
set_engine("partykit") %>%
set_mode("censored regression")
partykit_fit <- partykit_spec %>%
fit(surv_var ~ ., data = whas500)
partykit_pred <- augment(partykit_fit, whas500, type = "time", eval_time = 100)
# comparisons
concordance_survival(rpart_pred, truth = surv_var, estimate = .pred_time )
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 concordance_survival standard 0.198
concordance_survival(partykit_pred, truth = surv_var, estimate = .pred_time )
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 concordance_survival standard 0.745
# the metric is different because the predictions are so different
rpart_pred %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#> surv_var .pred_time
#> <Surv> <dbl>
#> 1 2178+ 0.695
#> 2 2172+ 0.175
#> 3 2190+ 0.175
#> 4 297 0.577
#> 5 2131+ 0.0296
#> 6 1 1.61
#> 7 2122+ 0.180
#> 8 1496 0.175
#> 9 920 3.51
#> 10 2175+ 0.175
#> # ℹ 490 more rows
partykit_pred %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#> surv_var .pred_time
#> <Surv> <dbl>
#> 1 2178+ 1174
#> 2 2172+ 2358
#> 3 2190+ 2358
#> 4 297 2350
#> 5 2131+ 2358
#> 6 1 2358
#> 7 2122+ 2358
#> 8 1496 2358
#> 9 920 187
#> 10 2175+ 2358
#> # ℹ 490 more rows
# the predictions might be so different because the models are quite different
rpart_fit %>% extract_fit_engine()
#> $rpart
#> n= 500
#>
#> node), split, n, deviance, yval
#> * denotes terminal node
#>
#> 1) root 500 695.561200 1.00000000
#> 2) age< 71.5 244 229.506500 0.38854860
#> 4) chf< 0.5 195 138.284600 0.22717860
#> 8) hr< 69.5 62 1.940894 0.02955276 *
#> 9) hr>=69.5 133 120.814900 0.33623700
#> 18) sysbp>=108 120 75.364020 0.24836930
#> 36) sysbp< 169.5 93 40.381390 0.17482940 *
#> 37) sysbp>=169.5 27 29.875960 0.57382780
#> 74) diasbp>=99.5 18 7.528452 0.18015450 *
#> 75) diasbp< 99.5 9 9.935203 1.82297300 *
#> 19) sysbp< 108 13 31.619030 1.60516100 *
#> 5) chf>=0.5 49 62.166150 1.11613100
#> 10) miord< 0.5 30 30.182480 0.57699630 *
#> 11) miord>=0.5 19 17.010920 2.44126100 *
#> 3) age>=71.5 256 355.500100 1.80819600
#> 6) age< 85.5 189 243.593000 1.41409700
#> 12) chf< 0.5 120 151.770400 1.04087300
#> 24) hr< 101 95 119.871800 0.87393850
#> 48) bmi>=20.87589 79 92.194110 0.69519550 *
#> 49) bmi< 20.87589 16 18.493300 2.00204100 *
#> 25) hr>=101 25 25.614090 1.80146000 *
#> 13) chf>=0.5 69 74.993720 2.27279500 *
#> 7) age>=85.5 67 81.206960 3.50931200 *
#>
#> $survfit
#>
#> Call: prodlim::prodlim(formula = form, data = data)
#> Stratified Kaplan-Meier estimator for the conditional event time survival function
#> Discrete predictor variable: rpartFactor (0.0295527632560162, 0.174829429431531, 0.180154466046163, 0.576996279039068, 0.695195545979456, 1.6051614128208, 1.80145950646331, 1.82297331997917, 2.00204080181348, 2.27279546726114, 2.44126127657507, 3.50931210269065)
#>
#> Right-censored response of a survival model
#>
#> No.Observations: 500
#>
#> Pattern:
#> Freq
#> event 215
#> right.censored 285
#>
#> $levels
#> [1] "0.0295527632560162" "0.174829429431531" "0.180154466046163"
#> [4] "0.576996279039068" "0.695195545979456" "1.6051614128208"
#> [7] "1.80145950646331" "1.82297331997917" "2.00204080181348"
#> [10] "2.27279546726114" "2.44126127657507" "3.50931210269065"
#>
#> attr(,"class")
#> [1] "pecRpart"
partykit_fit %>% extract_fit_engine()
#>
#> Model formula:
#> surv_var ~ age + gender + hr + sysbp + diasbp + bmi + cvd + afb +
#> sho + chf + av3 + miord + mitype
#>
#> Fitted party:
#> [1] root
#> | [2] age <= 72
#> | | [3] chf <= 0: 2358.000 (n = 204)
#> | | [4] chf > 0
#> | | | [5] miord <= 0: 2350.000 (n = 33)
#> | | | [6] miord > 0: 400.500 (n = 20)
#> | [7] age > 72
#> | | [8] chf <= 0: 1174.000 (n = 141)
#> | | [9] chf > 0: 187.000 (n = 102)
#>
#> Number of inner nodes: 4
#> Number of terminal nodes: 5
# refit the model with the same predictors (to have more similar models)
rpart_fit_age <- rpart_spec %>%
fit(surv_var ~ age + chf, data = whas500)
rpart_pred_age <- augment(rpart_fit_age, whas500, type = "time", eval_time = 100)
rpart_pred_age %>% select(surv_var, .pred_time)
#> # A tibble: 500 × 2
#> surv_var .pred_time
#> <Surv> <dbl>
#> 1 2178+ 1.04
#> 2 2172+ 0.227
#> 3 2190+ 0.227
#> 4 297 1.12
#> 5 2131+ 0.227
#> 6 1 0.227
#> 7 2122+ 0.227
#> 8 1496 0.227
#> 9 920 3.51
#> 10 2175+ 0.227
#> # ℹ 490 more rows

Created on 2024-09-06 with reprex v2.1.0
Thank you very much @hfrick for your quick response and for simplifying the problem. Clearly, it is a problem with the "time" predictions, which seem to be measured in other units with rpart; they must mean something other than time.
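For what it's worth, the rpart .pred_time values in the reprex above match the terminal-node values (yval) printed in the rpart tree, which for rpart's survival (exponential) splitting method are relative event rates, scaled so that the root node is 1, rather than times. Below is a minimal check, a sketch that assumes the rpart_fit, rpart_pred, and whas500 objects from the reprex above are still in the session; the names engine_fit and node_rate are only illustrative.

# Sketch: compare censored's "time" predictions with the raw node
# predictions from the underlying rpart tree. predict() on an rpart
# survival tree returns the fitted node value, i.e. a relative event
# rate, not a survival time.
engine_fit <- extract_fit_engine(rpart_fit)  # the "pecRpart" list shown above
node_rate <- predict(engine_fit$rpart, newdata = whas500)
head(cbind(node_rate = node_rate, .pred_time = rpart_pred$.pred_time))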
When I run decision_tree() with the "rpart" engine, the concordance index reported by tune(), and also the one computed on a test sample, is below 0.5, when I would expect it to be above 0.5. I provide an example with a well-known survival data set (Hosmer's whas500), taken from the "smoothHR" package.
I have tested other data sets and the results are similar. If I repeat the same script with the "partykit" engine, I get a concordance index above 0.7, which is as expected.
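For reference, a minimal sketch of the kind of test-sample check described here, assuming whas500 has already been prepared with a surv_var column as in the reprex above; the object names (whas_split, whas_train, whas_test, rpart_test_fit) are only illustrative.

library(tidymodels)
library(censored)

set.seed(1)
whas_split <- initial_split(whas500)  # illustrative train/test split
whas_train <- training(whas_split)
whas_test <- testing(whas_split)

rpart_test_fit <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("censored regression") %>%
  fit(surv_var ~ ., data = whas_train)

# concordance index on the held-out test sample
augment(rpart_test_fit, whas_test, type = "time", eval_time = 100) %>%
  concordance_survival(truth = surv_var, estimate = .pred_time)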