diff --git a/src/aihero/research/finetuning/infer.py b/src/aihero/research/finetuning/infer.py index 86552fa..617900f 100644 --- a/src/aihero/research/finetuning/infer.py +++ b/src/aihero/research/finetuning/infer.py @@ -1,4 +1,5 @@ """Module to run batch inference jobs.""" +import gc import os from pathlib import Path from tempfile import TemporaryDirectory @@ -233,6 +234,16 @@ def run(self) -> None: print("Save and Uploading model..") finish() + def cleanup(self) -> None: + """Clean up memory useage.""" + del self.model + del self.tokenizer + del self.dataset_dict + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + gc.collect() + class BatchInferenceWithEval: """Batch inference class for generating predictions and running custom tests and metrics.""" diff --git a/src/aihero/research/finetuning/train.py b/src/aihero/research/finetuning/train.py index fd1e6e8..fbbb21a 100644 --- a/src/aihero/research/finetuning/train.py +++ b/src/aihero/research/finetuning/train.py @@ -1,4 +1,5 @@ """Launch the training job inside a container.""" +import gc import os import time import traceback @@ -428,3 +429,13 @@ def run(self) -> None: print("Saving model..") self.save_model() finish() + + def cleanup(self) -> None: + """Clean up memory useage.""" + del self.model + del self.tokenizer + del self.dataset_dict + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + gc.collect()