Cross validation built into the framework #20544
Hey @svechinsky, this is a nice way to do it: https://gist.github.com/ashleve/ac511f08c0d29e74566900fd3efbb3ec

```python
from typing import Optional

import lightning as L
from sklearn.model_selection import KFold
from torch.utils.data import Dataset
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated


class ProteinsKFoldDataModule(L.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data/",
        k: int = 0,  # fold index (0-based)
        split_seed: int = 12345,  # the split must be identical for every fold for correct cross-validation
        num_splits: int = 10,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()
        # makes the init params accessible through the `self.hparams` attribute
        self.save_hyperparameters(logger=False)
        # num_splits = 10 means the dataset is split into 10 parts,
        # so each fold trains on 90% of the data and validates on 10%
        assert 0 <= self.hparams.k < self.hparams.num_splits, "incorrect fold number"
        # data transformations
        self.transforms = None
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None

    @property
    def num_node_features(self) -> int:
        return 4

    @property
    def num_classes(self) -> int:
        return 2

    def setup(self, stage=None):
        # split only once, even if setup() is called again by the Trainer
        if self.data_train is None and self.data_val is None:
            dataset_full = TUDataset(self.hparams.data_dir, name="PROTEINS", use_node_attr=True, transform=self.transforms)
            # choose the fold to train on
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = list(kf.split(dataset_full))
            train_indexes, val_indexes = all_splits[self.hparams.k]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()
            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(dataset=self.data_train, batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.data_val, batch_size=self.hparams.batch_size,
                          num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory)
```

```python
results = []
num_folds = 10
split_seed = 12345

for k in range(num_folds):
    # pass any other init args (data_dir, batch_size, ...) here as well
    datamodule = ProteinsKFoldDataModule(k=k, num_splits=num_folds, split_seed=split_seed)
    datamodule.prepare_data()
    datamodule.setup()

    # here we train the model on the given split...
    model = ...
    ...
    trainer = L.Trainer(...)
    trainer.fit(model, datamodule)
    # assumes the model logs a metric named "val_acc"; use whatever metric your model logs
    score = trainer.callback_metrics["val_acc"].item()
    results.append(score)

score = sum(results) / num_folds
```

Typically you'd instantiate a separate model and trainer for each fold, which is not too bad. We could add a version of this example to the docs. Here's an example we already have with Fabric: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/kfold_cv

What do you think?
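Why `split_seed` has to stay fixed: each fold gets its own datamodule instance, which recomputes the split from scratch, so every instance must shuffle identically for the validation folds to be disjoint. The stdlib-only sketch below illustrates the idea behind `KFold(shuffle=True, random_state=...)`; `kfold_indices` is a hypothetical helper, and it uses Python's `random` rather than NumPy, so it will not reproduce sklearn's exact indices.

```python
import random


def kfold_indices(n_samples: int, num_splits: int, k: int, seed: int):
    """Return (train_idx, val_idx) for fold k (0-based) of a seeded, shuffled K-fold split."""
    if not 0 <= k < num_splits:
        raise ValueError("fold index must satisfy 0 <= k < num_splits")
    indices = list(range(n_samples))
    # Same seed -> same permutation, regardless of which fold is requested.
    random.Random(seed).shuffle(indices)
    # As in sklearn's KFold, the first n_samples % num_splits folds get one extra sample.
    base, extra = divmod(n_samples, num_splits)
    start = k * base + min(k, extra)
    stop = start + base + (1 if k < extra else 0)
    val_idx = indices[start:stop]
    train_idx = indices[:start] + indices[stop:]
    return train_idx, val_idx


# The validation folds are disjoint and together cover every sample exactly once.
folds = [kfold_indices(23, 5, k, seed=12345)[1] for k in range(5)]
print([len(f) for f in folds])  # [5, 5, 5, 4, 4]
```

If each fold used a different seed, some samples would land in several validation folds and others in none, which defeats the purpose of cross-validation.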
That's exactly what I ended up doing!
Makes sense. What would be your ideal CLI design in this case?
Description & Motivation
Cross-validation is standard practice in many cases.
I personally use it to gain more confidence in models trained on borderline amounts of data.
There is currently no built-in way to do this with Lightning or LightningCLI.
What I'm currently doing is passing the fold index manually as an argument and letting the datamodule handle fold creation.
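A minimal sketch of that manual approach, assuming a plain `argparse` entry point; the flag names `--fold` and `--num-folds` and the `select_fold` helper are hypothetical. Each fold is launched as a separate run (e.g. `python train.py --fold 3`), and the split is derived from the fold index alone, so every run sees a consistent partition.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Train on a single cross-validation fold")
    parser.add_argument("--fold", type=int, required=True, help="0-based fold index")
    parser.add_argument("--num-folds", type=int, default=10)
    args = parser.parse_args(argv)
    if not 0 <= args.fold < args.num_folds:
        parser.error("--fold must be in [0, num_folds)")
    return args


def select_fold(samples, fold, num_folds):
    """Hypothetical helper: round-robin assignment, sample i belongs to fold i % num_folds."""
    val = [s for i, s in enumerate(samples) if i % num_folds == fold]
    train = [s for i, s in enumerate(samples) if i % num_folds != fold]
    return train, val


args = parse_args(["--fold", "2", "--num-folds", "5"])
train, val = select_fold(list(range(12)), args.fold, args.num_folds)
print(val)  # [2, 7]
```

Round-robin assignment keeps the sketch short but assumes the samples are not ordered (e.g. by class); a seeded shuffle before assignment is the safer choice in practice. The downside this issue points at remains: the per-fold runs are independent, so aggregating their scores is left to the user.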
Pitch
In my opinion, the ideal scenario would be fold-enabled datamodules that generate different train/validation/test sets based on a fold index passed in by Lightning.
This would let the user keep full control over fold selection and splitting while treating the whole cross-validation as a single run.
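To make the pitch more concrete, here is a sketch of what such an interface could look like. None of this exists in Lightning today: `FoldAwareDataModule`, `set_fold`, and `run_cross_validation` are hypothetical names, and a toy datamodule plus a toy scoring function stand in for a real `LightningDataModule` and `Trainer.fit`. The runner owns the fold loop and reports one aggregate result, while the datamodule decides how folds are built.

```python
import statistics
from typing import Callable, List, Protocol, Sequence, Tuple


class FoldAwareDataModule(Protocol):
    """Hypothetical interface: a datamodule that can rebuild its splits for a given fold."""

    num_folds: int

    def set_fold(self, fold: int) -> None: ...

    def train_data(self) -> Sequence[float]: ...

    def val_data(self) -> Sequence[float]: ...


class ToyKFoldData:
    """Toy implementation: round-robin folds over a fixed list of numbers."""

    def __init__(self, values: Sequence[float], num_folds: int) -> None:
        self.values = list(values)
        self.num_folds = num_folds
        self.fold = 0

    def set_fold(self, fold: int) -> None:
        if not 0 <= fold < self.num_folds:
            raise ValueError("invalid fold index")
        self.fold = fold

    def train_data(self) -> List[float]:
        return [v for i, v in enumerate(self.values) if i % self.num_folds != self.fold]

    def val_data(self) -> List[float]:
        return [v for i, v in enumerate(self.values) if i % self.num_folds == self.fold]


def run_cross_validation(
    data: FoldAwareDataModule,
    fit_and_score: Callable[[Sequence[float], Sequence[float]], float],
) -> Tuple[float, float, List[float]]:
    """Run every fold inside one job and return (mean, stdev, per-fold scores)."""
    scores = []
    for fold in range(data.num_folds):
        data.set_fold(fold)
        # A real runner would build a fresh model and Trainer here for each fold.
        scores.append(fit_and_score(data.train_data(), data.val_data()))
    return statistics.mean(scores), statistics.stdev(scores), scores


def mean_predictor(train, val):
    """Toy 'model': predict the training mean; score = negative mean absolute error."""
    pred = statistics.mean(train)
    return -statistics.mean([abs(v - pred) for v in val])


mean_score, std_score, per_fold = run_cross_validation(ToyKFoldData(range(20), 4), mean_predictor)
print(len(per_fold), round(mean_score, 3))
```

Keeping the split logic behind `set_fold` means the runner never needs to know whether folds are random, stratified, or grouped; it only needs `num_folds`.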
Alternatives
No response
Additional context
No response
cc @lantiga @Borda