W&B wandb support #699

Merged: 19 commits, Dec 3, 2024
Changes from 10 commits
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
@@ -6,3 +6,4 @@ sentencepiece
tiktoken
blobfile
tabulate
wandb
30 changes: 7 additions & 23 deletions README.md
@@ -73,7 +73,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118
```

### Downloading a tokenizer
@@ -85,11 +85,11 @@ Once you have confirmed access, you can run the following command to download th
```bash
# Get your HF token from https://huggingface.co/settings/tokens

# Llama 3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
# Llama 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3.1-8B --tokenizer_path "original" --hf_token=...

# Llama 2 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
# Llama 3.2 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-3.2-3B --hf_token=...
Contributor: This was indeed for Llama 2, not for Llama 3.2. I think we can remove Llama 2 files if they are not helpful anymore.

Member Author: I can send a separate PR deprecating Llama 2.

Contributor: Sounds good. Can we revert this change for now, as torchtitan doesn't support Llama 3.2 atm?

```

### Start a training run
@@ -99,25 +99,9 @@ Llama 3 8B model locally on 8 GPUs
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
```

## Logging

## TensorBoard

To visualize TensorBoard metrics of models trained on a remote server via a local web browser:

1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).

2. Set up SSH tunneling, by running the following from local CLI
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
```
tensorboard --logdir=./outputs/tb
```

4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.

We support logging via both TensorBoard and Weights & Biases. All you need to do is enable it in your `toml` file or from the CLI using `enable_tensorboard` or `enable_wandb` respectively. You can learn more [here](docs/metrics.md).
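For illustration, the relevant keys might look like this (a sketch: the key names and values are taken from `train_configs/debug_model.toml` and `torchtitan/config_manager.py` in this PR, and are assumed to live under a `[metrics]` table, as the `--metrics.*` CLI flags suggest):

```toml
[metrics]
log_freq = 1                  # how often to log metrics, in iterations
enable_color_printing = true  # colored console logs
enable_tensorboard = true     # write TensorBoard event files under save_tb_folder
save_tb_folder = "tb"
enable_wandb = false          # set to true (and run `wandb login`) to log to W&B
```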

## Multi-Node Training
For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.
36 changes: 36 additions & 0 deletions docs/metrics.md
@@ -0,0 +1,36 @@
# Metrics

We support automatically collecting metrics such as:
1. High-level system metrics such as MFU, average loss, max loss, and words per second
2. Memory metrics to measure max VRAM consumption and the number of OOMs
3. Timing metrics to measure data loading bottlenecks

Those metrics can then be visualized in either a TensorBoard or W&B dashboard.

## TensorBoard

To visualize TensorBoard metrics of models trained on a remote server via a local web browser:

1. Make sure the `metrics.enable_tensorboard` option is set to true in the model training config (either from a `.toml` file or from the CLI; a CLI sketch follows this list).

2. Set up SSH tunneling by running the following from your local CLI:
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. In the SSH session on the remote server, go to the torchtitan repo and start the TensorBoard backend:
```
tensorboard --logdir=./outputs/tb
```

4. In your local web browser, go to the URL it provides or to http://localhost:6006/.
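If you prefer the CLI route, here is a sketch (this assumes `run_llama_train.sh` forwards extra flags to the trainer, which is not shown in this diff):
```
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh \
    --metrics.enable_tensorboard --metrics.save_tb_folder "tb"
```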

## Weights and Biases

Weights & Biases will automatically send metrics to a remote server if you log in with `wandb login`.

So all you need to do is make sure that `metrics.enable_wandb` is enabled.

For an example, you can inspect [debug_model.toml](../train_configs/debug_model.toml).

Note that if both W&B and TensorBoard are enabled then we will prioritize W&B.
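As a rough sketch of a W&B-enabled run (again assuming `run_llama_train.sh` forwards extra flags to the trainer):
```
# Authenticate once on the training host; wandb stores your API key locally
wandb login

# Launch training with W&B metrics enabled (flag defined in torchtitan/config_manager.py)
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --metrics.enable_wandb
```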
19 changes: 13 additions & 6 deletions torchtitan/config_manager.py
@@ -121,15 +121,16 @@ def __init__(self):
help="How often to log metrics to TensorBoard, in iterations",
)
self.parser.add_argument(
"--metrics.enable_color_printing",
default=False,
"--metrics.enable_tensorboard",
action="store_true",
help="Whether to enable color printing",
default=False,
help="Whether to log metrics to TensorBoard",
)
self.parser.add_argument(
"--metrics.enable_tensorboard",
"--metrics.enable_color_printing",
action="store_true",
help="Whether to log metrics to TensorBoard",
default=True,
Contributor: Was flipping the enable_color_printing default intended? If so, the action store_true does not fit well anymore. It's now true by default.

help="Whether to enable color printing in logs",
)
self.parser.add_argument(
"--metrics.save_tb_folder",
@@ -139,14 +139,20 @@ )
)
self.parser.add_argument(
"--metrics.rank_0_only",
default=True,
action="store_true",
default=True,
help="""
Whether to save TensorBoard metrics only for rank 0 or for all ranks.
When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group,
which is the only stage that computes loss metrics.
""",
)
self.parser.add_argument(
"--metrics.enable_wandb",
action="store_true",
default=False,
help="Whether to log metrics to Weights & Biases",
)

# model configs
self.parser.add_argument(
150 changes: 110 additions & 40 deletions torchtitan/metrics.py
@@ -7,7 +7,7 @@
import os
from collections import namedtuple
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import torch
from torch.utils.tensorboard import SummaryWriter
@@ -16,7 +16,6 @@
from torchtitan.parallelisms import ParallelDims
from torchtitan.utils import device_module, device_type


# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
"DeviceMemStats",
@@ -88,31 +87,68 @@ def reset_peak_stats(self):


def build_device_memory_monitor():
device_memory_monitor = DeviceMemoryMonitor(device_type)
device_memory_monitor = DeviceMemoryMonitor()
logger.info(
f"{device_type.upper()} capacity: {device_memory_monitor.device_name} ({device_memory_monitor.device_index}) "
f"{device_type.upper()} capacity: {device_memory_monitor.device_name}"
f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
)

return device_memory_monitor


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
class DummyLogger:
"""Logger that does nothing, used when logging is disabled."""

def log(self, metrics: Dict[str, Any], step: int) -> None:
pass

def close(self) -> None:
pass


class TensorBoardLogger:
"""Logger implementation for TensorBoard."""

def __init__(self, log_dir: str, tag: Optional[str] = None):
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)
self.writer = SummaryWriter(log_dir, max_queue=1000)
logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

def log(self, metrics: Dict[str, Any], step: int) -> None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)

def close(self) -> None:
self.writer.close()


class WandBLogger:
"""Logger implementation for Weights & Biases."""

def __init__(self, log_dir: str, tag: Optional[str] = None):
# Import wandb here to avoid startup import
import wandb

self.wandb = wandb
self.tag = tag

self.wandb.init(
project="torchtitan",
dir=log_dir,
)
logger.debug("WandB logging enabled")

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)
def log(self, metrics: Dict[str, Any], step: int) -> None:
wandb_metrics = {
(k if self.tag is None else f"{self.tag}/{k}"): v
for k, v in metrics.items()
}
wandb_metrics["step"] = step
self.wandb.log(wandb_metrics)

def close(self):
if self.writer is not None:
self.writer.close()
def close(self) -> None:
if self.wandb.run is not None:
self.wandb.finish()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
@@ -126,35 +162,69 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0

return metrics_log_rank


def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
) -> Union[DummyLogger, TensorBoardLogger, WandBLogger]:
"""
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
Build an appropriate metric logger based on configuration.
"""
metrics_config = job_config.metrics

# Log initial config state
logger.info(
f"Building logger with config: wandb={metrics_config.enable_wandb}, "
f"tensorboard={metrics_config.enable_tensorboard}"
)

# Check if any logging backend is enabled
has_logging_enabled = (
metrics_config.enable_tensorboard or metrics_config.enable_wandb
)

# Determine if this rank should log
should_log = has_logging_enabled
if metrics_config.rank_0_only and should_log:
metrics_rank = _get_metrics_rank(parallel_dims)
should_log = torch.distributed.get_rank() == metrics_rank

logger.info(
f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
)

if not should_log:
logger.info("Returning DummyLogger due to should_log=False")
return DummyLogger()

# Setup logging directory
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = tb_config.enable_tensorboard
if enable_tb:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
base_log_dir = os.path.join(
dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
)

if not metrics_config.rank_0_only:
base_log_dir = os.path.join(
base_log_dir, f"rank_{torch.distributed.get_rank()}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)

return MetricLogger(log_dir, tag, enable_tb)
# Create loggers in priority order
if metrics_config.enable_wandb:
logger.info("Attempting to create WandB logger")
try:
return WandBLogger(base_log_dir, tag)
except Exception as e:
if "No module named 'wandb'" in str(e):
logger.error(
"Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
)
else:
logger.error(f"Failed to create WandB logger: {e}")

if metrics_config.enable_tensorboard:
logger.info("Creating TensorBoard logger")
return TensorBoardLogger(base_log_dir, tag)

logger.info("No loggers enabled, returning DummyLogger")
return DummyLogger()
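To make the selection logic above concrete, here is a minimal usage sketch. It is illustrative only: `job_config`, `parallel_dims`, `num_training_steps`, and `run_train_step` stand in for whatever the real training loop provides, and the metric key is invented for the example.

```python
# Sketch only (not the actual train.py): assumes job_config and parallel_dims are
# built the way torchtitan's training entry point builds them.
metric_logger = build_metric_logger(job_config, parallel_dims)

for step in range(1, num_training_steps + 1):
    loss = run_train_step()  # hypothetical helper returning a float
    # DummyLogger, TensorBoardLogger, and WandBLogger expose the same log/close
    # interface, so the call site never checks which backend (if any) is active.
    metric_logger.log({"loss/global_avg": loss}, step)

metric_logger.close()
```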
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -17,6 +17,7 @@ log_freq = 1
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"