The reference validation-loss and train + validation MSE curves look way smoother than the results I'm seeing with my implementation. It also seems odd that the training loss doesn't appear to go down at all during training (first chart), yet I end up with a test MSE close to 0.2, which, as noted, is what I should expect. I tried lowering the learning rate, but that doesn't help. Any thoughts @rasbt? Thank you in advance 🙏
Here's the implementation, for reference:

```python
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import lightning as L
import torchmetrics
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import StandardScaler

class Model(torch.nn.Module):
    def __init__(self, num_features: int):
        super().__init__()
        self.all_layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_features, 50),
            torch.nn.ReLU(),
            # 2nd hidden layer
            torch.nn.Linear(50, 25),
            torch.nn.ReLU(),
            # output layer
            torch.nn.Linear(25, 1),
        )

    def forward(self, x):
        return self.all_layers(x).flatten()

class LightningModel(L.LightningModule):
    def __init__(self, model: torch.nn.Module, learning_rate: float):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.train_metric = LightningModel.build_metric()
        self.val_metric = LightningModel.build_metric()
        self.test_metric = LightningModel.build_metric()

    def _compute_loss(
        self, batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        X, y_true = batch
        y_hat = self(X)
        loss = F.mse_loss(y_hat, y_true)
        return y_true, y_hat, loss

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.model(X)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        y_true, y_hat, loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        # Update the MSE metric with the current batch
        self.train_metric(y_hat, y_true)
        # Accumulate the training MSE across batches, report once per epoch
        self.log(
            "train_mse", self.train_metric, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        y_true, y_hat, loss = self._compute_loss(batch)
        self.log("valid_loss", loss, prog_bar=True)
        # Update the MSE metric with the current batch
        self.val_metric(y_hat, y_true)
        # Report after each epoch
        self.log(
            "valid_mse", self.val_metric, prog_bar=True,
        )
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
        )

    def test_step(self, batch, batch_idx):
        y_true, y_hat, _ = self._compute_loss(batch)
        self.test_metric(y_hat, y_true)
        self.log("mse", self.test_metric)

    @staticmethod
    def build_metric() -> torchmetrics.Metric:
        return torchmetrics.MeanSquaredError()

class AmesHousingDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        columns = ['Overall Qual', 'Overall Cond', 'Gr Liv Area',
                   'Central Air', 'Total Bsmt SF', 'SalePrice']
        df = pd.read_csv(csv_path, usecols=columns)
        # df['Central Air'] = df['Central Air'].map({'N': 0, 'Y': 1})
        df = df.dropna(axis=0)
        X = df[['Overall Qual',
                'Gr Liv Area',
                'Total Bsmt SF']].values
        y = df['SalePrice'].values
        sc_x = StandardScaler()
        sc_y = StandardScaler()
        X_std = sc_x.fit_transform(X)
        y_std = sc_y.fit_transform(y[:, np.newaxis]).flatten()
        self.x = torch.tensor(X_std, dtype=torch.float)
        self.y = torch.tensor(y_std, dtype=torch.float).flatten()

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.y.shape[0]

class AmesHousingDataModule(L.LightningDataModule):
    def __init__(self,
                 csv_path='http://jse.amstat.org/v19n3/decock/AmesHousing.txt',
                 base_path="../data",
                 batch_size=32):
        super().__init__()
        self.base_path = Path(base_path)
        self.csv_path = csv_path
        self.batch_size = batch_size

    def prepare_data(self):
        if not os.path.exists(self.base_path / 'AmesHousing.txt'):
            df = pd.read_csv(self.csv_path, sep="\t")
            df.to_csv(self.base_path / 'AmesHousing.txt', index=None)

    def setup(self, stage: str):
        all_data = AmesHousingDataset(csv_path=self.base_path / 'AmesHousing.txt')
        temp, self.val = random_split(all_data, [2500, 429],
                                      torch.Generator().manual_seed(1))
        self.train, self.test = random_split(temp, [2000, 500],
                                             torch.Generator().manual_seed(1))

    def train_dataloader(self):
        return DataLoader(
            self.train, batch_size=self.batch_size,
            shuffle=True, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, shuffle=False)

class TrainingProcess:
    def __init__(self, **trainer_kws):
        self.dm = AmesHousingDataModule(
            batch_size=trainer_kws.pop("batch_size", 32),
            csv_path=trainer_kws.pop("csv_path"),
            base_path=trainer_kws.pop("base_path", ".")
        )
        trainer_kws.setdefault("max_epochs", 10)
        trainer_kws.setdefault("accelerator", "auto")
        trainer_kws.setdefault("devices", "auto")
        self.pytorch_model = Model(num_features=3)
        self.lightning_model = LightningModel(model=self.pytorch_model, learning_rate=0.05)
        self.trainer = L.Trainer(
            **trainer_kws,
        )

    def run(self):
        self.trainer.fit(
            model=self.lightning_model,
            datamodule=self.dm
        )
        train_mse = self.trainer.validate(dataloaders=self.dm.train_dataloader())[0]["valid_mse"]
        val_mse = self.trainer.validate(datamodule=self.dm)[0]["valid_mse"]
        test_mse = self.trainer.test(datamodule=self.dm)[0]["mse"]
        # MSE is not a percentage, so report the raw values on the standardized target
        print(
            f"Train MSE {train_mse:.4f}"
            f" | Val MSE {val_mse:.4f}"
            f" | Test MSE {test_mse:.4f}"
        )
        # PATH = "lightning.pt"
        # torch.save(self.pytorch_model.state_dict(), PATH)

if __name__ == "__main__":
    process = TrainingProcess(max_epochs=30, base_path=".", csv_path="data/ames-housing.csv")
    process.run()
```
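One presentation detail that may matter when comparing against the reference plots (an assumption on my side, not a confirmed cause): in the code above, `train_loss` is logged with the `training_step` defaults (`on_step=True`, `on_epoch=False`), so the first chart shows the raw per-batch loss, which stays noisy even while the epoch average falls. A minimal sketch of also logging an epoch-averaged training loss:

```python
# Hypothetical tweak, not part of the implementation above: additionally
# aggregate the training loss over each epoch so the resulting curve is
# directly comparable to the per-epoch validation curves.
self.log("train_loss", loss, on_step=True, on_epoch=True)
```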
Replies: 1 comment
It doesn't look too bad, but yeah, it's weird that it's different. Are you using the exact same code, or did you make changes? Besides the learning rate, which you already checked, the batch size could also be a potential reason.
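A quick way to test the batch-size idea against the code in the question is a rough sweep like the sketch below. Only `batch_size` is exposed through `TrainingProcess`'s keyword arguments; the learning rate is hard-coded to 0.05 in that class, so varying it would need a small edit.

```python
# Rough sketch: re-run training with a few batch sizes and compare how
# smooth the logged curves look. Reuses TrainingProcess exactly as defined
# in the question; the paths match the ones used in its __main__ block.
for batch_size in (16, 32, 64, 128):
    print(f"\n=== batch_size={batch_size} ===")
    process = TrainingProcess(
        max_epochs=30,
        base_path=".",
        csv_path="data/ames-housing.csv",
        batch_size=batch_size,
    )
    process.run()
```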