diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 5cbc58da..3d0499ac 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -22,7 +22,7 @@ class Tracking(object): - supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "console"] + supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console"] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): if isinstance(default_backend, str): @@ -81,6 +81,9 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li ) self.logger['vemlp_wandb'] = vemlp_wandb + if 'tensorboard' in default_backend: + self.logger['tensorboard'] = _TensorboardAdapter() + if 'console' in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger self.console_logger = LocalLogger(print_to_console=True) @@ -98,6 +101,26 @@ def __del__(self): self.logger['swanlab'].finish() if 'vemlp_wandb' in self.logger: self.logger['vemlp_wandb'].finish(exit_code=0) + if 'tensorboard' in self.logger: + self.logger['tensorboard'].finish() + + +class _TensorboardAdapter: + + def __init__(self): + from torch.utils.tensorboard import SummaryWriter + import os + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log") + os.makedirs(tensorboard_dir, exist_ok=True) + print(f"Saving tensorboard log to {tensorboard_dir}.") + self.writer = SummaryWriter(tensorboard_dir) + + def log(self, data, step): + for key in data: + self.writer.add_scalar(key, data[key], step) + + def finish(self): + self.writer.close() class _MlflowLoggingAdapter: