From be0d893f08253d8f987c80ff2f294cd90a0cbef8 Mon Sep 17 00:00:00 2001 From: Rahul Parundekar Date: Sun, 5 May 2024 11:16:14 -0700 Subject: [PATCH] load distributed config --- launch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launch.py b/launch.py index 7dc0fa2..6c80203 100644 --- a/launch.py +++ b/launch.py @@ -9,8 +9,8 @@ def train(training_config_file: str = "/mnt/config/training/config.yaml") -> None: """Run Training.""" - training_config = TrainingJob.load(training_config_file, is_distributed=int(os.getenv("WORLD_SIZE", 1)) > 1) - TrainingJobRunner(training_config).run() + training_config = TrainingJob.load(training_config_file) + TrainingJobRunner(training_config, is_distributed=int(os.getenv("WORLD_SIZE", 1)) > 1).run() def infer(batch_inference_config_file: str = "/mnt/config/batch_inference/config.yaml") -> None: