# utils.py
from typing import Any, List, Optional, Union
import time
import pathlib
import logging
import os

import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import Tensor, optim, nn
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.core.saving import save_hparams_to_yaml
# Import the Logger base class and rank-zero helpers from the same
# `pytorch_lightning` namespace as the Trainer above; mixing `pytorch_lightning`
# and `lightning.pytorch` imports breaks the Trainer's isinstance checks.
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
import hydra
from omegaconf import DictConfig, OmegaConf
import torchmetrics
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix
from torchattacks import PGD, FGSM, FFGSM, APGD, TPGD, CW

from architectures.MKToyNet import MKToyNet
from architectures.WideResNet import WideResNet16
from architectures.ResNet import ResNet18, ResNet50
from architectures.LeNet import LeNet

torch.manual_seed(313)  # reproducibility
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)  # silence Lightning's INFO messages


class KhalooeiLoggingLogger(Logger):
    """Minimal Lightning logger that mirrors metrics and hyperparameters to
    the console and to a timestamped log file."""

    def __init__(self, save_dir='logs/', version=None):
        super().__init__()  # initialize the Logger base class
        # Default the version to a timestamp computed at call time
        # (an f-string default argument would be frozen at import time).
        self._version = version if version is not None else time.strftime('%Y%m%d%H%M%S')
        # Create the log directory if it doesn't exist
        s_output_dir = os.path.join(save_dir, 'lightning_logs', self._version)
        os.makedirs(s_output_dir, exist_ok=True)
        # Include the timestamp in the log file name
        experiment_name = f"experiment_{self._version}.log"
        # Configure logging with the experiment-specific log file
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s [%(levelname)s]: %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(s_output_dir, experiment_name)),
                logging.StreamHandler(),
            ],
        )
        self.logger = logging.getLogger(__name__)

    @property
    def name(self):
        return "KhalooeiLoggingLogger"

    @property
    def version(self):
        # Return the experiment version (int or str)
        return self._version

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        # Log metrics to the console and log file
        for key, value in metrics.items():
            self.logger.info(f"{key}: {value}")

    @rank_zero_only
    def log_hyperparams(self, params):
        # Log hyperparameters to the console and log file.
        # Lightning may pass an argparse.Namespace rather than a dict.
        if not isinstance(params, dict):
            params = vars(params)
        self.logger.info("Hyperparameters:")
        for key, value in params.items():
            self.logger.info(f"{key}: {value}")

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the underlying experiment object (required by the Logger API);
        # this logger writes straight to `logging`, so there is none.
        return None

    @rank_zero_only
    def save(self):
        # Optional. Flush buffered log data; the logging handlers write
        # immediately, so there is nothing to do here.
        pass

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to run after training finishes
        # goes here.
        pass
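
# Usage note (illustrative sketch): because log_metrics/log_hyperparams only
# need plain dicts, the logger can also be exercised directly, outside a Trainer:
#
#   demo_logger = KhalooeiLoggingLogger(save_dir='logs/')
#   demo_logger.log_hyperparams({"lr": 0.01, "batch_size": 128})
#   demo_logger.log_metrics({"train_loss": 0.42}, step=0)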


class CustomTimeCallback(Callback):
    """Report wall-clock training time via Lightning's callback hooks."""

    def on_train_start(self, trainer, pl_module):
        self.start = time.time()
        print("Training is starting ...")

    def on_train_end(self, trainer, pl_module):
        self.end = time.time()
        total_minutes = (self.end - self.start) / 60
        print(f"Training is finished. It took {total_minutes:.2f} min.")