Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Oct 24, 2024
1 parent e746a4b commit 9e25c6a
Showing 1 changed file with 7 additions and 49 deletions.
56 changes: 7 additions & 49 deletions dolomite_engine/pretrain.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@
from .data import get_megatron_gpt_dataloaders, get_next_batch
from .distributed import wrap_model_container_for_distributed_training
from .enums import FP8Backend, Mode, TuningMethod
from .model_wrapper import ModelWrapperForPretraining, get_model_container
from .model_wrapper import get_model_container
from .optimization import get_optimizer_container, get_scheduler_container
from .train_utils import all_reduce_metrics_tracker, get_model_tflops, get_torch_profiler, track_metrics, train_step
from .utils import (
Expand Down Expand Up @@ -213,7 +213,7 @@ def train(
metrics_tracker = MetricsTrackingDict({})

if eval_during_training and (global_step % eval_interval == 0 or global_step == num_training_steps):
evaluate(val_dataloaders, model, global_step, experiments_tracker, eval_steps, group_names)
evaluate(val_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names)

if global_step % save_interval == 0 or global_step == num_training_steps:
save_checkpoint(
Expand All @@ -233,7 +233,7 @@ def train(
steps_since_start_time = 0

if eval_during_training:
evaluate(test_dataloaders, model, global_step, experiments_tracker, eval_steps, group_names)
evaluate(test_dataloaders, model_container, global_step, experiments_tracker, eval_steps, group_names)

if torch_profiler is not None:
torch_profiler.__exit__()
Expand Down Expand Up @@ -263,54 +263,12 @@ def evaluate(
"""

if ProcessGroupManager.get_pipeline_parallel_world_size() > 1:
metrics_tracker = _train_step_with_pipeline_parallel(
model_container=model_container,
pipeline_schedule=pipeline_schedule,
optimizer_container=optimizer_container,
lr_scheduler_container=lr_scheduler_container,
train_dataloader=train_dataloader,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_clipping=gradient_clipping,
batch_size=batch_size,
sequence_length=sequence_length,
)
else:
assert len(model_container) == 1

metrics_tracker = _evaluate_without_pipeline_parallel(
val_dataloaders=val_dataloaders,
model=model_container[0],
global_step=global_step,
experiments_tracker=experiments_tracker,
eval_steps=eval_steps,
group_names=group_names,
raise NotImplementedError(
"pipeline parallel doesn't support evaluation yet, pass eval_during_training = false"
)

return metrics_tracker


@torch.no_grad()
def _evaluate_without_pipeline_parallel(
val_dataloaders: list[DataLoader],
model: ModelWrapperForPretraining,
global_step: int,
experiments_tracker: ExperimentsTracker,
eval_steps: int,
group_names: list[str],
) -> float:
"""main validation loop for the program
Args:
val_dataloaders (list[DataLoader]): list of validation dataloaders
model (ModelWrapperForPretraining): model
global_step (int): global step during training
experiments_tracker (ExperimentsTracker): metrics tracker
eval_steps (int): number of steps to run eval for
group_names (list[str]): names of the datasets in validation/test group
Returns:
MetricsTrackingDict: metrics tracker
"""
assert len(model_container) == 1
model = model_container[0]

tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()

Expand Down

0 comments on commit 9e25c6a

Please sign in to comment.