Commit 95ca65b (parent 06c4623): 1 changed file with 95 additions and 0 deletions.
import argparse

import optuna
from optuna.integration import PyTorchLightningPruningCallback
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from data_loader import Train_Test_Split, ESPDataset, ESPDataModule
from env import *  # provides DAILY_OUTPUT_FOLDER and SLIDE_N
from train import ESPFailureModel

def objective(trial, seed, split):
    # Hyperparameter search space
    hidden_size = trial.suggest_categorical('hidden_size', [32, 64, 128, 256])
    dropout = trial.suggest_float('dropout', 0.0, 0.5, step=0.1)
    num_stack_layers = trial.suggest_int('num_stack_layers', 1, 3)
    num_epochs = trial.suggest_categorical('num_epochs', [150, 200, 250, 300])
    learning_rate = trial.suggest_categorical('learning_rate', [1e-3, 1e-4, 1e-5])
    num_layers = trial.suggest_int('num_layers', 1, 3)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128, 256])

    # Set seed for reproducibility
    pl.seed_everything(seed=seed)

    tts = Train_Test_Split(f"{DAILY_OUTPUT_FOLDER}_{SLIDE_N}", split=split)
    data_paths = tts.split_data()

    # Create the dataloaders
    data_module = ESPDataModule(train_paths=data_paths["train"],
                                val_paths=data_paths["val"],
                                test_paths=data_paths["test"],
                                batch_size=batch_size)

    # Load a single file to get the model dimensions
    single_batch = next(iter(ESPDataset(data_paths["val"][:1])))
    n_features = single_batch["features"].shape[-1]
    n_classes = single_batch["labels"].shape[-1]

    # Initialize the model
    model = ESPFailureModel(n_features=n_features,
                            n_classes=n_classes,
                            lr=learning_rate,
                            dropout=dropout,
                            hidden_size=hidden_size,
                            num_stack_layers=num_stack_layers,
                            n_layers=num_layers)

    # Define the model callbacks
    checkpoint_call_back = ModelCheckpoint(
        dirpath=f"checkpoints/{trial.number}",
        filename="best-chckpt",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min"
    )

    # Report val_loss to Optuna after each validation epoch so that
    # unpromising trials can be pruned early
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")

    logger = TensorBoardLogger(save_dir="lightning_logs", name="JTK_Challenge")

    trainer = pl.Trainer(logger=logger,
                         callbacks=[checkpoint_call_back, pruning_callback],
                         max_epochs=num_epochs,
                         deterministic=True,
                         enable_progress_bar=True)

    # Record the searched hyperparameters alongside the model
    model.save_hyperparameters({"hidden_dim": hidden_size,
                                "learning_rate": learning_rate,
                                "dropout": dropout,
                                "num_stack_layers": num_stack_layers,
                                "num_layers": num_layers,
                                "seed": seed,
                                "split": split,
                                "batch_size": batch_size,
                                "device": trainer.accelerator})

    trainer.fit(model, data_module)

    return trainer.callback_metrics["val_loss"].item()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparameter search for ESP Failure model")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--split", type=float, default=0.8)
    parser.add_argument("--n_trials", type=int, default=300)
    args = parser.parse_args()

    study = optuna.create_study(direction='minimize')
    study.optimize(lambda trial: objective(trial, args.seed, args.split), n_trials=args.n_trials)

    best_params = study.best_params
    print("Best hyperparameters:", best_params)
    print("Best value:", study.best_value)