Scalability over larger datasets #5
Hi,
Is your method scalable to larger datasets? I tried running it on a dataset of size (10000, 8) and got an estimated run time as below. This should not be the case, since your own test dataset is of size (300, 8) and its time per iteration is low. Are you retraining the model to compute the SHAP values for each example? It is not clear to me why the time per iteration increases so much when the number of features stays the same.

Comments
Hi, I'm interested in where this problem might come from, but I can't reproduce it. Could you share a minimal reproducible code example, with this dataset or any other that causes the problem? I believe it's also related to #4.
Thanks for the prompt reply. Yes, this is related to issue #4, which I posted earlier. We can reproduce it by creating a random array of size (1000, 8). Here's a simple example of how to reproduce it:

```python
import numpy as np

# create a random dataset
X = np.random.rand(1000, 8)

# run random survival forest
from sksurv.ensemble import RandomSurvivalForest

# run survshap
from survshap import SurvivalModelExplainer, ModelSurvSHAP

pnd_survshap_global_rsf = ModelSurvSHAP(random_state=42)
```
Hi @krzyzinskim, were you able to reproduce this?
Hi @Addicted-to-coding @solidate, it's expected to be slow. The implemented (default) algorithm aims to approximate Shapley values "exactly" and is therefore useful for relatively small (background) datasets. So you can probably compute SurvSHAP(t) for 1000+ samples, but using only 100-200 samples as the background for estimation. Another way to speed up the calculations is to reduce the number of timestamps (the `timestamps` parameter). Also, RSF has slow inference, which adds to the time. See the comparison with a simpler CPH model:

```python
import numpy as np
import pandas as pd
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from survshap import SurvivalModelExplainer, ModelSurvSHAP

# create a random dataset with 1000 samples and 8 features
X = np.random.rand(1000, 8)
X = pd.DataFrame(X, columns=['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8'])

# build the structured (event, time) array that scikit-survival expects
y = np.random.rand(1000, 1)
boo = np.random.choice(a=[True, False], size=(1000, 1), p=[0.5, 0.5])
out = np.empty(1000, dtype=[('event', '?'), ('time', '<f8')])
out['event'] = boo.reshape(-1)
out['time'] = y.reshape(-1)

# Cox proportional hazards model (fast inference)
cph = CoxPHSurvivalAnalysis()
cph.fit(X, out)
cph.score(X, out)

# random survival forest (slow inference)
rsf = RandomSurvivalForest(random_state=42, n_estimators=120, max_depth=8, max_features=3)
rsf.fit(X, out)
rsf.score(X, out)

# SurvSHAP(t) explanations: explaining CPH is much faster than explaining RSF
exp_cph = SurvivalModelExplainer(cph, X, out)
ms_cph = ModelSurvSHAP(random_state=42)
ms_cph.fit(exp_cph)

exp_rsf = SurvivalModelExplainer(rsf, X, out)
ms_rsf = ModelSurvSHAP(random_state=42)
ms_rsf.fit(exp_rsf)
```
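To make the suggested speed-up concrete, here is a minimal sketch (reusing `X`, `out`, and `rsf` from the snippet above) that subsamples the background dataset to 200 rows and coarsens the time grid. The 200-row background and the five-fold thinning are arbitrary choices, and passing `timestamps` to `fit()` is an assumption about the signature, so check it against the installed version:

```python
import numpy as np
from survshap import SurvivalModelExplainer, ModelSurvSHAP

# use a small random subsample as the background dataset for estimation
rng = np.random.default_rng(42)
idx = rng.choice(len(X), size=200, replace=False)
exp_rsf_small = SurvivalModelExplainer(rsf, X.iloc[idx], out[idx])

ms_rsf_small = ModelSurvSHAP(random_state=42)
# assumption: fit() accepts a `timestamps` argument; here we keep every
# fifth unique event time of the fitted forest to shrink the time grid
ms_rsf_small.fit(exp_rsf_small, timestamps=rsf.event_times_[::5])
```

With a 200-row background instead of all 1000 rows, each Shapley estimate evaluates the model on far fewer samples, which is the part of the computation that grows with the background size.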