Commit
Merge branch 'main' into in-loop-gsm
dirkgr authored Feb 4, 2025
2 parents 16d82cf + 8527ffe commit 7dde292
Showing 9 changed files with 430 additions and 23 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add GSM8K to in-loop evals (BPB over correct continuation)
- Support for specifying custom dataset objects in the `data` section of the config file.


## [v0.6.0](https://github.com/allenai/OLMo/releases/tag/v0.6.0) - 2024-12-17

164 changes: 164 additions & 0 deletions configs/custom_dataset_example.yaml
@@ -0,0 +1,164 @@
run_name: OLMo-7B
seed: 6198
dry_run: false

# Comment out the wandb section if you don't wish to use wandb for logging
# Otherwise update it according to your account info and its setup on the cluster
#wandb:
# name: ${run_name}
# project: olmo-medium
# group: OLMo-7B

model:
d_model: 4096
n_heads: 32
n_layers: 32
mlp_hidden_size: 16384 # default: 4 * d_model
weight_tying: false
alibi: false
rope: true
flash_attention: false
attention_dropout: 0.0
attention_layer_norm: true
multi_query_attention: false
include_bias: false
block_type: sequential
layer_norm_type: default
layer_norm_with_affine: true
bias_for_layer_norm: false
attention_layer_norm_with_affine: true
activation_type: gelu
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 2048
vocab_size: 32100 # depends on the tokenizer
embedding_size: 32128
eos_token_id: 1
pad_token_id: 0
init_device: cuda
init_fn: mitchell

compile: # causes instability on AMD GPUs
fullgraph: false

activation_checkpointing: whole_layer

optimizer:
name: adamw
learning_rate: 3.0e-4
weight_decay: 0.1
betas:
- 0.9
- 0.95
metrics_log_interval: 100

scheduler:
name: cosine_with_warmup
t_warmup: 1000
alpha_f: 0.1
grad_clip_warmup_steps: 1000
grad_clip_warmup_factor: 10.0

tokenizer:
identifier: t5-base
truncate_direction: right

save_folder: ${oc.env:CHECKPOINTS_PATH}
remote_save_folder: null
save_overwrite: false
# Sharded checkpoints (best for restarts)
save_interval: 100
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 515 # 135M tokens (global_train_batch_size * max_sequence_length * max_duration)
global_train_batch_size: 128 # 4 GPUs * 32
device_train_microbatch_size: 32
time_limit: null

precision: amp_bf16

ddp:
grad_sync_mode: batch

distributed_strategy: fsdp

fsdp:
wrapping_strategy: by_block
precision: mixed
sharding_strategy: FULL_SHARD

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: 5 # limit how many batches to evaluate on (set to -1 for all batches)
device_eval_batch_size: ${device_train_microbatch_size}

evaluators:

##########################
# Downstream evaluations #
##########################
- label: piqa
type: downstream

- label: hellaswag
type: downstream

- label: winogrande
type: downstream

- label: openbook_qa
type: downstream

# - label: boolq # requires implementation of the pmi_dc matrix
# type: downstream

- label: sciq
type: downstream

- label: arc_easy
type: downstream

# - label: arc_challenge # requires implementation of the pmi_dc matrix
# type: downstream

- label: copa
type: downstream

- label: rte
type: downstream

- label: commitment_bank
type: downstream

- label: mrpc
type: downstream

- label: sst2
type: downstream

data:
pad_direction: right
num_workers: 2
drop_last: true
pin_memory: true
prefetch_factor: 1
persistent_workers: true
timeout: 0
custom_dataset:
name: "<module>.<class>"
args:
arg1: "value1"
arg2: "value2"
arg3: "value3"
collate_config:
input_id_field: "token_ids" # The field in the dataset that contains the input ids
29 changes: 29 additions & 0 deletions olmo/config.py
@@ -621,6 +621,7 @@ class DataConfig(BaseConfig):
timeout: int = 0
seed: Optional[int] = None
instance_filter: Optional[InstanceFilterConfig] = None
custom_dataset: Optional[CustomDatasetConfig] = None

@property
def effective_memmap_dtype(self):
@@ -633,6 +634,34 @@ def effective_memmap_dtype(self):
return dtype


@dataclass
class CustomDatasetCollatorConfig(BaseConfig):
input_id_field: str = "input_ids" #: The field in the dataset items that contains the input token IDs.
attention_mask_field: Optional[str] = None #: The field in the dataset items that contains the attention mask.
attention_bias_field: Optional[str] = None #: The field in the dataset items that contains the attention bias.
label_mask_field: Optional[str] = None #: The field in the dataset items that contains the label mask.
index_field: Optional[str] = None #: The field in the dataset items that contains the index of the item.
instance_mask_field: Optional[str] = None #: The field in the dataset items that contains the instance mask.
doc_lens_field: Optional[str] = None #: The field in the dataset items that contains the document lengths.
metadata_field: Optional[str] = None #: The field in the dataset items that contains the metadata.


@dataclass
class CustomDatasetConfig(BaseConfig):
name: str #: The name of the custom dataset class or function that will be used to load the dataset.
module: Optional[
str
] = None #: The module where the custom dataset class is defined. If not set, the module will be inferred from the class name.
args: Optional[Dict[str, Any]] = None #: The arguments to pass to the custom dataset class or function
collate_fn: Optional[
str
] = None #: The name of the collate function to use for the custom dataset. Assumes the collate function is defined in the same module as the custom dataset class unless specified otherwise using the full object path.
token_field: Optional[str] = None #: The field in the dataset items that contains the tokenized text.
collate_config: Optional[CustomDatasetCollatorConfig] = field(
default_factory=CustomDatasetCollatorConfig
) #: The configuration for the collate function to use for the custom dataset.


class EvaluatorType(StrEnum):
downstream = "downstream"
lm = "lm"
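
As a rough illustration of how the two new dataclasses fit together, the snippet below mirrors the YAML example in Python. The values are made up; only the field names come from the definitions above.

from olmo.config import CustomDatasetCollatorConfig, CustomDatasetConfig

custom_dataset = CustomDatasetConfig(
    name="my_datasets.JsonlTokenDataset",  # hypothetical <module>.<class>
    args={"path": "/data/train.jsonl", "max_sequence_length": 2048},  # hypothetical constructor args
    collate_fn=None,  # None falls back to CustomDatasetDataCollator
    collate_config=CustomDatasetCollatorConfig(input_id_field="token_ids"),
)
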
90 changes: 69 additions & 21 deletions olmo/data/__init__.py
@@ -1,3 +1,5 @@
import importlib
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, cast

@@ -7,12 +9,15 @@
from ..config import DataConfig, TrainConfig
from ..exceptions import OLMoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size
from .collator import DataCollator
from .collator import CustomDatasetDataCollator, DataCollator
from .custom_datasets import build_custom_dataset, extract_module_and_class
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset

__all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"]

LOGGER = logging.getLogger(__name__)


def build_memmap_dataset(
train_config: TrainConfig, data_config: DataConfig, include_instance_metadata: bool = True
@@ -48,6 +53,42 @@ def build_memmap_dataset(
)


def build_collator(train_config: TrainConfig) -> DataCollator:
"""Returns a collator for the train dataloader. Either returns the default
collator or a custom collator specified in the train config.
:param train_config: OLMo train config
:raises OLMoConfigurationError: Raises an error if the collate function is not found
:return: Collator for the train dataloader
"""
if train_config.data.custom_dataset:
if train_config.data.custom_dataset.collate_fn:
module, function = extract_module_and_class(train_config.data.custom_dataset.collate_fn)
if module is None:
if train_config.data.custom_dataset.module is None:
module, _ = extract_module_and_class(train_config.data.custom_dataset.name)
else:
module = train_config.data.custom_dataset.module
try:
assert module is not None
collator = getattr(importlib.import_module(module), function)
except AttributeError:
raise OLMoConfigurationError(
f"collate_fn {train_config.data.custom_dataset.collate_fn} not found in {module}. Please specify the full module path of the function."
)
return collator

return CustomDatasetDataCollator(
pad_direction=train_config.data.pad_direction,
pad_token_id=train_config.model.pad_token_id,
**train_config.data.custom_dataset.collate_config.asdict(), # type: ignore
)
else:
return DataCollator(
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
)


def build_eval_dataloader(
train_config: TrainConfig,
data_config: DataConfig,
@@ -92,12 +133,18 @@ def build_train_dataloader(
include_instance_metadata: bool = False,
) -> DataLoader:
assert train_config.device_train_batch_size is not None
collator = DataCollator(
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
)
dataset = build_memmap_dataset(
train_config, train_config.data, include_instance_metadata=include_instance_metadata
)
seed = train_config.data.seed if train_config.data.seed is not None else train_config.seed
collator = build_collator(train_config)
if train_config.data.custom_dataset:
if train_config.data.paths is not None or train_config.data.datasets is not None:
raise OLMoConfigurationError(
"custom_dataset_class is mutually exclusive with DataConfig.paths and DataConfig.datasets"
)
dataset = build_custom_dataset(train_config)
else:
dataset = build_memmap_dataset(
train_config, train_config.data, include_instance_metadata=include_instance_metadata
)
work_dir = Path(train_config.save_folder) / "train_data"
if get_global_rank() == 0:
if work_dir.is_dir() and not train_config.save_overwrite:
@@ -106,21 +153,21 @@
)
else:
work_dir.mkdir(exist_ok=True, parents=True)
dataset = IterableDataset(
dataset, # type: ignore
train_config.global_train_batch_size,
seed=seed,
epoch=train_config.epoch or 0,
shuffle=True,
drop_last=train_config.data.drop_last,
world_size=world_size,
rank=rank,
fs_local_rank=fs_local_rank,
work_dir=work_dir,
)
barrier()
seed = train_config.data.seed if train_config.data.seed is not None else train_config.seed
return DataLoader(
IterableDataset(
dataset, # type: ignore
train_config.global_train_batch_size,
seed=seed,
epoch=train_config.epoch or 0,
shuffle=True,
drop_last=train_config.data.drop_last,
world_size=world_size,
rank=rank,
fs_local_rank=fs_local_rank,
work_dir=work_dir,
),
out = DataLoader(
dataset,
batch_size=train_config.device_train_batch_size,
drop_last=train_config.data.drop_last,
collate_fn=collator,
@@ -130,3 +177,4 @@
persistent_workers=False if train_config.data.num_workers == 0 else train_config.data.persistent_workers,
timeout=train_config.data.timeout,
)
return out
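
When collate_fn is set, build_collator above imports the named function and passes it straight to the DataLoader, so a custom collate function only needs to be importable and to turn a list of dataset items into a batch dict. A hedged sketch follows; the item field name and the choice to emit only an "input_ids" key are assumptions, not something this diff specifies.

# my_datasets.py -- hypothetical; referenced from the config as custom_dataset.collate_fn: "my_datasets.simple_collate"
from typing import Any, Dict, List

import torch
import torch.nn.functional as F


def simple_collate(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Right-pad variable-length token sequences into one LongTensor batch."""
    max_len = max(len(x["token_ids"]) for x in items)
    input_ids = torch.stack(
        [
            F.pad(torch.as_tensor(x["token_ids"], dtype=torch.long), (0, max_len - len(x["token_ids"])), value=0)
            for x in items
        ]
    )
    return {"input_ids": input_ids}
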
36 changes: 35 additions & 1 deletion olmo/data/collator.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -138,3 +138,37 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
out["metadata"] = all_metadata

return out


@dataclass
class CustomDatasetDataCollator(DataCollator):
input_id_field: str = "input_ids"
attention_mask_field: Optional[str] = None
attention_bias_field: Optional[str] = None
label_mask_field: Optional[str] = None
index_field: Optional[str] = None
instance_mask_field: Optional[str] = None
doc_lens_field: Optional[str] = None
metadata_field: Optional[str] = None

def _relabel_fields(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return [self._relabel_item(x) for x in items]

def _relabel_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
out = {
"input_ids": item[self.input_id_field],
"attention_mask": item[self.attention_mask_field] if self.attention_mask_field else None,
"attention_bias": item[self.attention_bias_field] if self.attention_bias_field else None,
"label_mask": item[self.label_mask_field] if self.label_mask_field else None,
"index": item[self.index_field] if self.index_field else None,
"instance_mask": item[self.instance_mask_field] if self.instance_mask_field else None,
"metadata": item[self.metadata_field] if self.metadata_field else None,
}
if self.doc_lens_field:
out["doc_lens"] = item.__getitem__(self.doc_lens_field)
return out

def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]:
if not isinstance(items[0], torch.Tensor):
items = self._relabel_fields(items)
return super().__call__(items)
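
To make the field relabeling concrete, here is a small usage sketch. The pad settings and field values are invented, and it assumes the inherited DataCollator accepts these keyword arguments and skips the None entries that _relabel_item produces (which is what the code above relies on).

import torch
from olmo.data.collator import CustomDatasetDataCollator

collator = CustomDatasetDataCollator(
    pad_direction="right",  # assumed to accept the same values as DataConfig.pad_direction
    pad_token_id=0,
    input_id_field="token_ids",  # items store their tokens under this key
)
batch = collator(
    [
        {"token_ids": torch.tensor([5, 6, 7])},
        {"token_ids": torch.tensor([8, 9])},
    ]
)
# batch["input_ids"] should come back as a right-padded (2, 3) LongTensor.
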