Skip to content

Commit

Permalink
order
Browse files Browse the repository at this point in the history
  • Loading branch information
wanchaol committed Apr 25, 2024
1 parent 37f59c5 commit ab787b3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ def loss_fn(pred, labels):

metric_logger = build_metric_logger(job_config)

train_state = TrainState()

if job_config.training.compile:
if (
job_config.activation_checkpoint.mode == "selective"
Expand All @@ -232,6 +230,8 @@ def loss_fn(pred, labels):
)
logger.info(f"Compiling each TransformerBlock with torch.compile")

train_state = TrainState()

# train loop
model.train()

Expand Down

0 comments on commit ab787b3

Please sign in to comment.