Large Dataset Struggles #8279
Unanswered
vmiller987 asked this question in Q&A
I'm trying to understand how people are handling data that is too large to fit on the GPU.
My work deals with 3D segmentations. I have multiple GPUs with 24GB VRAM each.
I've made a Lightning model for UNet. It works fairly well for most of our datasets, but not for the large ones.
I use RandCropByPosNegLabel as a training transform to perform patch-based training during the training step.
I use sliding_window_inference during validation/prediction. This worked until the large dataset got involved.
Sliding window inference returns the logits/predictions for the entire image. These are too large to fit on the GPU and cause OOM errors, so I have to move the logits to the CPU and compute the loss there.
Is there something I might be misunderstanding or missing with SWI? Is there an easier way to perform the loss function while keeping it on the GPU?
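To make the question concrete, here is a minimal sketch of the kind of thing I am hoping is possible: accumulating the loss statistics window by window, so the full-image logits never have to exist in one place. It is plain Python on a toy 1D example, not MONAI or PyTorch code. The helper names (`dice_terms`, `dice_from_terms`, `chunks`, `streaming_dice`) and the toy data are hypothetical, and it assumes non-overlapping windows; with overlapping windows the blended predictions would not decompose this simply.

```python
def dice_terms(pred, target):
    """Return (intersection, pred_sum, target_sum) for two flat sequences."""
    inter = sum(p * t for p, t in zip(pred, target))
    return inter, sum(pred), sum(target)


def dice_from_terms(inter, pred_sum, target_sum, eps=1e-8):
    """Soft Dice score from the three accumulated sums."""
    return (2.0 * inter + eps) / (pred_sum + target_sum + eps)


def chunks(seq, size):
    """Yield consecutive, non-overlapping windows of at most `size` items."""
    for start in range(0, len(seq), size):
        yield seq[start:start + size]


def streaming_dice(pred, target, window):
    """Accumulate the Dice sums one window at a time.

    Only three running scalars are kept between windows, which stands in
    for keeping per-window results on the GPU instead of the whole volume.
    """
    inter = pred_sum = target_sum = 0.0
    for p_win, t_win in zip(chunks(pred, window), chunks(target, window)):
        i, p, t = dice_terms(p_win, t_win)
        inter += i
        pred_sum += p
        target_sum += t
    return dice_from_terms(inter, pred_sum, target_sum)


# Toy stand-ins for foreground probabilities and binary labels (hypothetical).
pred = [0.9, 0.8, 0.1, 0.2, 0.7, 0.4, 0.05, 0.95]
target = [1, 1, 0, 0, 1, 0, 0, 1]

full = dice_from_terms(*dice_terms(pred, target))    # whole "image" at once
streamed = streaming_dice(pred, target, window=3)    # window by window
print(full, streamed)
```

The two values agree up to floating-point rounding because the Dice numerator and denominator are plain sums over voxels. Whether something equivalent can be hooked into sliding_window_inference is exactly what I am unsure about.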
The other solution I've found is to use the spacing transform and downsample the image, but I would like to be able to train at full resolution.
I tried to implement GridPatchDataset, but it is extremely slow: nearly twice as slow as computing the loss on the CPU. To me, it looks like it pauses between images to generate the patches for the next image to be evaluated. If I enable the cache functionality, it becomes 2-3x slower and ends up consuming 300-500 GB of RAM.
Any tips or advice would be greatly appreciated. I do not have anyone with machine learning experience to speak with, and this certainly makes the work harder.