Custom train loop to perform partial batch updates #20319
leonardcaquot94 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Hi there,
I'm working on a recursive data-processing task in PyTorch Lightning, where the batch shrinks with each iteration as predictions are completed, until no more are needed. I want to address two key issues:
I think I need to change how `training_step` and `validation_step` work. Instead of iterating over a single batch until all of its recursive predictions are made, these methods should perform just one step at a time. After each step, completed elements should be removed from the batch, while the remaining ones stay for further processing. This way, I can keep the batch size more consistent and improve efficiency.
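A minimal sketch of the "one step, then retire finished elements" idea above, written in plain Python so the scheduling logic stands on its own. `Item`, `ActivePool`, and `toy_predict` are hypothetical names, and the integer `state`/`steps_left` fields are stand-ins for real tensors and a real stopping criterion. As an assumption beyond simply dropping finished elements, the pool also refills freed slots from the data source, which is what keeps the batch size close to `batch_size` until the data runs out.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterable, Iterator, List


@dataclass
class Item:
    """Hypothetical sample being processed recursively."""
    uid: int
    state: int        # stand-in for the sample's current input/state tensor
    steps_left: int   # stand-in for "more predictions are still needed"
    outputs: List[int] = field(default_factory=list)


class ActivePool:
    """Hypothetical helper: holds up to `batch_size` unfinished items and
    refills freed slots from `source` before every step."""

    def __init__(self, source: Iterable[Item], batch_size: int):
        self.source: Iterator[Item] = iter(source)
        self.batch_size = batch_size
        self.active: List[Item] = []
        self.finished: List[Item] = []
        self._exhausted = False

    def refill(self) -> None:
        # Top the batch back up with fresh samples while any are left.
        while len(self.active) < self.batch_size and not self._exhausted:
            try:
                self.active.append(next(self.source))
            except StopIteration:
                self._exhausted = True

    def step(self, predict: Callable[[List[Item]], List[int]]) -> int:
        """Run ONE recursive step on the active batch, retire finished
        items, and return the size of the batch that was processed."""
        self.refill()
        batch = self.active
        if not batch:
            return 0
        preds = predict(batch)  # one forward pass over the whole batch
        still_active = []
        for item, pred in zip(batch, preds):
            item.outputs.append(pred)
            item.state = pred   # feed the prediction back as the next input
            item.steps_left -= 1
            (still_active if item.steps_left > 0 else self.finished).append(item)
        self.active = still_active
        return len(batch)


def toy_predict(batch: List[Item]) -> List[int]:
    # Hypothetical stand-in for model(batch): "prediction" = state + 1.
    return [item.state + 1 for item in batch]


# Five samples needing 1, 3, 2, 1 and 4 recursive steps respectively.
items = [Item(uid=i, state=0, steps_left=n) for i, n in enumerate([1, 3, 2, 1, 4])]
pool = ActivePool(items, batch_size=3)
sizes = []
while (n := pool.step(toy_predict)) > 0:
    sizes.append(n)
```

With this input the processed batch sizes come out as `[3, 3, 2, 1, 1, 1]`: the batch stays full while there are samples left to pull in and only shrinks at the very end. How to wire this into Lightning is left open here; one possible route is to keep the pool on the `LightningModule` and run one `step()` per `training_step` call, computing the loss on that single step. That only works if each recursive step can be trained on its own (for example, the state carried between steps is detached); if gradients must flow through the whole recursion, the per-step split does not apply as-is.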
Any ideas that could help?