From c850a42c982ab63cf7f618f595fc69977bbfee32 Mon Sep 17 00:00:00 2001
From: Warley Vital Barbosa
Date: Wed, 20 Mar 2024 18:12:08 -0300
Subject: [PATCH] refactor facornet to use pytorch lightning (#71)

---
 ours/datasets/facornet.py |  47 +++++++
 ours/guild-config.yml     |   7 +
 ours/guild.yml            |   4 +-
 ours/models/facornet.py   | 186 ++++++++++++++++++++++++++
 ours/tasks/facornet.py    | 275 +++----------------------------------
 5 files changed, 260 insertions(+), 259 deletions(-)

diff --git a/ours/datasets/facornet.py b/ours/datasets/facornet.py
index 1e5805c..6e4b7eb 100644
--- a/ours/datasets/facornet.py
+++ b/ours/datasets/facornet.py
@@ -1,3 +1,9 @@
+from pathlib import Path
+
+import lightning as pl
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
 from .fiw import FIW
 
 
@@ -13,6 +19,47 @@ class FIWFaCoRNet(FIW):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
+    def __getitem__(self, item):
+        (img1, img2, labels) = super().__getitem__(item)
+        # Convert img1 and img2 to BGR - they're tensors (C, H, W)
+        # img1 = img1[[2, 1, 0], :, :]
+        # img2 = img2[[2, 1, 0], :, :]
+        return img1, img2, labels
+
+
+class FaCoRNetDataModule(pl.LightningDataModule):
+    def __init__(self, batch_size=20, root_dir=".", transform=None):
+        super().__init__()
+        self.batch_size = batch_size
+        self.root_dir = root_dir
+        self.transform = transform or transforms.Compose([transforms.ToTensor()])
+
+    def setup(self, stage=None):
+        if stage == "fit" or stage is None:
+            self.train_dataset = FIW(
+                root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TRAIN_PAIRS), transform=self.transform
+            )
+            self.val_dataset = FIW(
+                root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL), transform=self.transform
+            )
+        # Trainer.validate() calls setup with stage="validate" (not "val")
+        if stage == "validate" or stage is None:
+            self.val_dataset = FIW(
+                root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL), transform=self.transform
+            )
+        if stage == "test" or stage is None:
+            self.test_dataset = FIW(
+                root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TEST_PAIRS), transform=self.transform
+            )
+
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)
+
+    def test_dataloader(self):
+        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)
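+
+
+# Minimal sketch of driving the datamodule directly (the root_dir mirrors the
+# __main__ example below; in normal use Lightning's Trainer calls setup() itself):
+#
+#   dm = FaCoRNetDataModule(batch_size=20, root_dir="../../datasets/")
+#   dm.setup("fit")
+#   img1, img2, labels = next(iter(dm.train_dataloader()))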
 
 
 if __name__ == "__main__":
     fiw = FIW(root_dir="../../datasets/", sample_path=FIWFaCoRNet.TRAIN_PAIRS)
diff --git a/ours/guild-config.yml b/ours/guild-config.yml
index e72ed09..ad64d2b 100644
--- a/ours/guild-config.yml
+++ b/ours/guild-config.yml
@@ -22,3 +22,10 @@ remotes:
     user: warley
     guild-env: ~/.virtualenvs/research
     guild-home: ~/.guild
+  rig-2-facor:
+    type: ssh
+    description: RIG2
+    host: rig2
+    user: warley
+    guild-env: ~/.virtualenvs/facor
+    guild-home: ~/.guild
diff --git a/ours/guild.yml b/ours/guild.yml
index d11d8a6..40ec675 100644
--- a/ours/guild.yml
+++ b/ours/guild.yml
@@ -112,7 +112,7 @@ operations:
   train:
     description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
-    main: tasks.facornet train
+    main: tasks.facornet fit
     sourcecode:
       - utils.py
       - losses.py
@@ -132,7 +132,7 @@ rename: data
   val:
     description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
-    main: tasks.facornet val
+    main: tasks.facornet validate
     sourcecode:
       - utils.py
       - losses.py
diff --git a/ours/models/facornet.py b/ours/models/facornet.py
index 38c4abd..96b4012 100644
--- a/ours/models/facornet.py
+++ b/ours/models/facornet.py
@@ -1,9 +1,14 @@
 from collections import namedtuple
 from pathlib import Path
 
+import lightning as pl
 import numpy as np
 import torch
 import torch.nn as nn
+import torchmetrics as tm
+from datasets.utils import Sample
+from losses import facornet_contrastive_loss
+from models.utils import compute_best_threshold
 from torch.nn import (
     BatchNorm1d,
     BatchNorm2d,
@@ -18,6 +23,8 @@
     Sigmoid,
 )
 
+# The Lightning wrapper below builds on FaCoR and the loss/dataset utilities imported above
+
 HERE = Path(__file__).parent
 
 adaface_models = {
@@ -636,6 +643,185 @@ def IR_SE_200(input_size):
     return model
 
 
+class DynamicThresholdAccuracy(tm.Metric):
+    def __init__(self, dist_sync_on_step=False):
+        # compute_on_step was removed from torchmetrics, so it is not forwarded here
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    # No __call__ override is needed: tm.Metric.forward already passes extra
+    # positional arguments (here, the threshold) through to update()
+    def update(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor):
+        # threshold is the scalar epoch-level best threshold; broadcasting handles the comparison
+        threshold = torch.as_tensor(threshold, device=preds.device)
+        preds_thresholded = preds >= threshold
+        correct = torch.sum(preds_thresholded == target)
+        self.correct += correct
+        self.total += target.numel()
+
+    def compute(self):
+        return self.correct.float() / self.total
+
+
+class CollectPreds(tm.Metric):
+    def __init__(self, dist_sync_on_step=False):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.add_state("predictions", default=[], dist_reduce_fx=None)
+
+    def update(self, preds: torch.Tensor):
+        # Move preds to the same device as the previously collected predictions
+        preds = preds.detach().to(self.predictions[0].device if self.predictions else preds.device)
+        # Append current batch predictions to the list of all predictions
+        self.predictions.append(preds)
+
+    def compute(self):
+        # Concatenate the list of predictions into a single tensor
+        return torch.cat(self.predictions, dim=0)
+
+    # No custom reset() is needed: tm.Metric.reset() restores the empty list state
+
+
+class FaCoRNetLightning(pl.LightningModule):
+    def __init__(self, lr=1e-4, momentum=0.9, weight_decay=0, weights_path=None, threshold=None, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.model = FaCoR()
+
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.loss_fn = facornet_contrastive_loss
+
+        self.threshold = threshold
+
+        self.similarities = CollectPreds()  # Custom metric to collect predictions
+        self.is_kin_labels = CollectPreds()  # Custom metric to collect labels
+        self.kin_labels = CollectPreds()  # Custom metric to collect labels
+
+        # Metrics
+        self.train_auc = tm.AUROC(task="binary")
+        self.val_auc = tm.AUROC(task="binary")
+        self.train_acc = DynamicThresholdAccuracy()
+        self.val_acc = DynamicThresholdAccuracy()
+        self.train_acc_kin_relations = tm.MetricCollection(
+            {f"train/acc_{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.values()}
+        )
+        self.val_acc_kin_relations = tm.MetricCollection(
+            {f"val/acc_{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.values()}
+        )
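+
+    # Design note (sketch): validation/test similarities are accumulated with
+    # CollectPreds over the whole epoch so that a single ROC-based threshold can
+    # be chosen in the *_epoch_end hooks rather than per batch, roughly:
+    #
+    #   sims = self.similarities.compute()  # every similarity seen this epoch
+    #   fpr, tpr, ths = tm.functional.roc(sims, labels, task="binary")
+    #   best = compute_best_threshold(tpr, fpr, ths)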
+
+    def setup(self, stage):
+        # TODO: use checkpoint callback to load the weights
+        if self.hparams.weights_path is not None:
+            map_location = "cuda" if torch.cuda.is_available() else "cpu"
+            try:
+                # Load the weights
+                state_dict = torch.load(self.hparams.weights_path, map_location=map_location)
+                self.model.load_state_dict(state_dict)
+                print(f"Loaded weights from {self.hparams.weights_path}")
+            except FileNotFoundError:
+                print(f"Failed to load weights from {self.hparams.weights_path}. File does not exist.")
+            except RuntimeError as e:
+                print(f"Failed to load weights due to a runtime error: {e}")
+
+    def forward(self, img1, img2):
+        return self.model([img1, img2])
+
+    def step(self, batch, stage="train"):
+        img1, img2, labels = batch
+        kin_relation, is_kin = labels
+        f1, f2, att = self.forward(img1, img2)
+        loss = self.loss_fn(f1, f2, beta=att)
+        sim = torch.cosine_similarity(f1, f2)
+
+        if stage == "train":
+            self.__compute_metrics(sim, is_kin.int(), kin_relation, stage)
+        else:
+            # Accumulate predictions so the epoch-level best threshold can be computed at epoch end
+            self.similarities(sim)
+            self.is_kin_labels(is_kin.int())
+            self.kin_labels(kin_relation)
+
+        self.log(f"{stage}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self.step(batch, "train")
+
+    def validation_step(self, batch, batch_idx):
+        self.step(batch, "val")
+
+    def test_step(self, batch, batch_idx):
+        self.step(batch, "test")
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(
+            self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay
+        )
+        return optimizer
+
+    def on_train_epoch_end(self):
+        # Lightning 2.x has no on_epoch_end hook; the train-specific hook replaces it
+        # Calculate the number of samples processed
+        use_sample = (self.current_epoch + 1) * self.trainer.datamodule.batch_size * self.trainer.limit_train_batches
+        # Update the dataset's bias (sampling strategy)
+        self.trainer.datamodule.train_dataset.set_bias(use_sample)
+
+    def on_validation_epoch_end(self):
+        # In Lightning 2.x the epoch-end hooks take no outputs argument
+        similarities = self.similarities.compute()
+        is_kin_labels = self.is_kin_labels.compute()
+        kin_labels = self.kin_labels.compute()
+        self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="val")
+        self.__reset_collectors()
+
+    def on_test_epoch_end(self):
+        similarities = self.similarities.compute()
+        is_kin_labels = self.is_kin_labels.compute()
+        kin_labels = self.kin_labels.compute()
+        self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="test")
+        self.__reset_collectors()
+
+    def __reset_collectors(self):
+        # Clear collected predictions so the next epoch starts from an empty state
+        self.similarities.reset()
+        self.is_kin_labels.reset()
+        self.kin_labels.reset()
+
+    def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="train"):
+        if stage == "test" and self.threshold is None:
+            raise ValueError("Threshold must be provided for test stage")
+        elif stage == "test":
+            best_threshold = self.threshold
+        else:  # Compute best threshold for training or validation
+            fpr, tpr, thresholds = tm.functional.roc(similarities, is_kin_labels, task="binary")
+            best_threshold = compute_best_threshold(tpr, fpr, thresholds)
+        self.log(f"{stage}/threshold", best_threshold, on_epoch=True, prog_bar=True, logger=True)
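+        # compute_best_threshold (models.utils) is expected to mirror the removed
+        # validate() below: it picks thresholds[(tpr - fpr).argmax()], i.e. the
+        # ROC point that maximizes Youden's J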
+
+        # Log AUC and accuracy
+        auc_fn = self.train_auc if stage == "train" else self.val_auc
+        acc_fn = self.train_acc if stage == "train" else self.val_acc
+        auc = auc_fn(similarities, is_kin_labels)  # AUROC scores the ranking; it takes no threshold
+        acc = acc_fn(similarities, is_kin_labels, best_threshold)
+        self.log(f"{stage}/auc", auc, on_epoch=True, prog_bar=True, logger=True)
+        self.log(f"{stage}/acc", acc, on_epoch=True, prog_bar=True, logger=True)
+
+        # Accuracy for each kinship relation
+        prefix = "train" if stage == "train" else "val"  # the test stage reuses the val collection keys
+        acc_kin_relations = self.train_acc_kin_relations if stage == "train" else self.val_acc_kin_relations
+        for kin_id in Sample.NAME2LABEL.values():
+            mask = kin_labels == kin_id
+            if torch.any(mask):
+                metric = acc_kin_relations[f"{prefix}/acc_{kin_id}"]
+                metric(similarities[mask], is_kin_labels[mask].int(), best_threshold)
+                self.log(f"{stage}/acc_{kin_id}", metric, on_epoch=True, prog_bar=True, logger=True)
+
 
 if __name__ == "__main__":
     model = IR_101((112, 112))
     print(model)
diff --git a/ours/tasks/facornet.py b/ours/tasks/facornet.py
index a1f5d30..aa4360c 100644
--- a/ours/tasks/facornet.py
+++ b/ours/tasks/facornet.py
@@ -1,267 +1,28 @@
-from argparse import ArgumentParser
-from pathlib import Path
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.cli import LightningCLI
+from models.facornet import FaCoRNetLightning
 
-import torch
-import torchmetrics as tm
-from datasets.utils import Sample
-from losses import facornet_contrastive_loss
-from models.facornet import FaCoR
-from torch.utils.data import DataLoader
-from torchvision import transforms
-from tqdm import tqdm
-from utils import TQDM_BAR_FORMAT, set_seed
+from datasets.facornet import FaCoRNetDataModule
 
-from datasets.facornet import FIWFaCoRNet as FIW
 
+class MyLightningCLI(LightningCLI):
+    def add_arguments_to_parser(self, parser):
+        # Add custom arguments or modify existing ones
+        parser.add_argument("--threshold", type=float, default=None, help="Threshold for classification")
 
-def acc_kr_to_str(out, acc_kr):
-    # Add acc_kr to out
-    id2name = {v: k for k, v in Sample.NAME2LABEL.items()}
-    for kin_id, acc in acc_kr.items():
-        kr = id2name[kin_id]
-        out += f" | acc_{kr}: {acc:.6f}"
-    return out
+        # Example of adding a callback argument
+        parser.add_lightning_class_args(ModelCheckpoint, "checkpoint")
 
+    def before_fit(self):
+        # With subcommands enabled, the parsed config is namespaced per subcommand
+        threshold = self.config["fit"]["threshold"]
+        if threshold is not None:
+            self.model.threshold = threshold
 
-@torch.no_grad()
-def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    dataset_size = len(val_loader.dataset)
-    # Preallocate tensors based on the total dataset size
-    similarities = torch.zeros(dataset_size, device=device)
-    y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
-    y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
-    current_index = 0
-    for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
-        batch_size_current = img1.size(0)  # Handle last batch potentially being smaller
-        img1, img2 = img1.to(device), img2.to(device)
-        (kin_relation, is_kin) = labels
-        kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)
-
-        f1, f2, _ = model([img1, img2])
-        sim = torch.cosine_similarity(f1, f2)
-
-        # Fill preallocated tensors
-        similarities[current_index : current_index + batch_size_current] = sim
-        y_true[current_index : current_index + batch_size_current] = is_kin
-        y_true_kin_relations[current_index : current_index + batch_size_current] = kin_relation
-
-        current_index += batch_size_current
-
-    return similarities, y_true, y_true_kin_relations
-
-
-def validate(model, dataloader, device=0, threshold=None):
-    model.eval()
-    # Compute similarities
-    similarities, y_true, y_true_kin_relations = predict(model, dataloader)
-    # Compute metrics
-    auc = tm.functional.auroc(similarities, y_true, task="binary")
-    fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
-    if threshold is None:
-        # Get the best threshold
-        maxindex = (tpr - fpr).argmax()
-        threshold = thresholds[maxindex]
-        if threshold.isnan().item():
-            threshold = 0.01
-        else:
-            threshold = threshold.item()
-    # Compute acc
-    acc = tm.functional.accuracy(similarities, y_true, task="binary", threshold=threshold)
-    # Compute accuracy with respect to kinship relations
-    acc_kin_relations = {}
-    for kin_relation in Sample.NAME2LABEL.values():
-        mask = y_true_kin_relations == kin_relation
-        acc_kin_relations[kin_relation] = tm.functional.accuracy(
-            similarities[mask], y_true[mask], task="binary", threshold=threshold
-        )
-    return auc, threshold, acc, acc_kin_relations
-
-
-def train(args):
-
-    set_seed(args.seed)
-
-    args.output_dir = Path(args.output_dir)
-    # Create output directory
-    args.output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Write args to args.yaml
-    with open(args.output_dir / "args.yaml", "w") as f:
-        f.write(str(args))
-
-    # Define transformations for training and validation sets
-    # Did they mentioned augmentations?
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-
-    train_dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.TRAIN_PAIRS), transform=transform)
-    val_model_sel_dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_MODEL_SEL), transform=transform)
-
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=False)
-    val_model_sel_loader = DataLoader(
-        val_model_sel_dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, shuffle=False
-    )
-
-    model = FaCoR()
-    model.to(args.device)
-
-    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    total_steps = len(train_loader)
-    print(f"Total steps: {total_steps}")
-    global_step = 0
-    best_model_auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
-    out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}"
-    # Add acc_kr to out
-    for kin_relation, acc in acc_kr.items():
-        out += f" | acc_{kin_relation}: {acc:.6f}"
-    print(out)
-
-    for epoch in range(args.num_epoch):
-        model.train()
-        loss_epoch = 0.0
-        for step, data in enumerate(train_loader):
-            global_step = step + epoch * args.steps_per_epoch
-
-            image1, image2, labels = data
-            (kin_relation, is_kin) = labels
-
-            image1 = image1.to(args.device)
-            image2 = image2.to(args.device)
-            kin_relation = kin_relation.to(args.device)
-            is_kin = is_kin.to(args.device)
-
-            x1, x2, att = model([image1, image2])
-            loss = facornet_contrastive_loss(x1, x2, beta=att)
-
-            loss_epoch += loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            if (step + 1) == args.steps_per_epoch:
-                break
-
-        use_sample = (epoch + 1) * args.batch_size * args.steps_per_epoch
-        train_dataset.set_bias(use_sample)
-
-        # Save model checkpoints
-        auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
-
-        if auc > best_model_auc:
-            best_model_auc = auc
-            torch.save(model.state_dict(), args.output_dir / "best.pth")
-
-        out = (
-            f"epoch: {epoch + 1:>2} | step: {global_step} "
-            + f"| loss: {loss_epoch / args.steps_per_epoch:.3f} | auc: {auc:.6f} | acc: {val_acc:.6f}"
-        )
-        out = acc_kr_to_str(out, acc_kr)
-        print(out)
-
-
-def val(args):
-
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-
-    dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)
-
-    model = FaCoR()
-    model.load_state_dict(torch.load(args.weights))
-    model.to(args.device)
-
-    auc, threshold, val_acc, acc_kr = validate(model, dataloader)
-    out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
-    out = acc_kr_to_str(out, acc_kr)
-    print(out)
-
-
-def test(args):
-
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-        ]
-    )
-
-    dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.TEST_PAIRS), transform=transform)
-    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)
-
-    model = FaCoR()
-    model.load_state_dict(torch.load(args.weights))
-    model.to(args.device)
-
-    auc, threshold, val_acc, acc_kr = validate(model, dataloader, threshold=args.threshold)
-    out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
-    out = acc_kr_to_str(out, acc_kr)
-    print(out)
-
-
-def create_parser_train(subparsers):
-    parser = subparsers.add_parser("train", help="Train the model")
-    parser.add_argument("--root-dir", type=str, required=True)
-    parser.add_argument("--output-dir", type=str, required=True)
-    parser.add_argument("--num-epoch", type=int, default=40, help="Number of epochs")
-    parser.add_argument("--steps-per-epoch", type=int, default=50, help="Steps per epoch")
-    parser.add_argument("--batch-size", type=int, default=25, help="Batch size")
-    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
-    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
-    parser.add_argument("--weight-decay", type=float, default=0, help="Weight decay")
-    parser.add_argument("--device", type=str, default="0", help="Device to use for training")
-    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility")
-    parser.set_defaults(func=train)
-
-
-def create_parser_val(subparsers):
-    parser = subparsers.add_parser("val", help="Select best threshold for the model")
-    parser.add_argument("--weights", type=str, required=True)
-    parser.add_argument("--root-dir", type=str, required=True)
-    parser.add_argument("--batch-size", type=int, default=100, help="Batch size")
-    parser.add_argument("--device", type=str, default="0", help="Device to use for training")
-    parser.set_defaults(func=val)
-
-
-def create_parser_test(subparsers):
-    parser = subparsers.add_parser("test", help="Test the model")
-    parser.add_argument("--weights", type=str, required=True)
-    parser.add_argument("--root-dir", type=str, required=True)
-    parser.add_argument("--batch-size", type=int, default=100, help="Batch size")
-    parser.add_argument("--threshold", type=float, required=True)
-    parser.add_argument("--device", type=str, default="0", help="Device to use for training")
-    parser.set_defaults(func=test)
+
+
+def main():
+    # Instantiating LightningCLI parses the CLI args and runs the chosen subcommand
+    # (fit/validate/test) immediately; there is no separate run() method to call
+    MyLightningCLI(FaCoRNetLightning, FaCoRNetDataModule)
 
 
 if __name__ == "__main__":
-
-    parser = ArgumentParser(description="Configuration for the FaCoRNet strategy")
-    subparsers = parser.add_subparsers()
-    create_parser_train(subparsers)
-    create_parser_val(subparsers)
-    create_parser_test(subparsers)
-    args = parser.parse_args()
-
-    # Necessary for dataloaders?
-    torch.multiprocessing.set_start_method("spawn")
-
-    print(args)
-
-    if torch.cuda.is_available():
-        args.device = torch.device(f"cuda:{args.device}")
-        current_device = torch.cuda.current_device()
-        device_name = torch.cuda.get_device_name(current_device)
-        print(f"Current CUDA Device = {current_device}")
-        print(f"Device Name = {device_name}")
-    else:
-        print("CUDA is not available.")
-
-    args.func(args)
+    main()
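+
+# Example invocations (a sketch: LightningCLI generates the --model.* and --data.*
+# flags from the __init__ signatures, --threshold comes from MyLightningCLI above,
+# and the dataset path is illustrative):
+#   python -m tasks.facornet fit --data.root_dir ../../datasets/ --model.lr 1e-4
+#   python -m tasks.facornet validate --data.root_dir ../../datasets/ --ckpt_path best.ckpt
+#   python -m tasks.facornet test --data.root_dir ../../datasets/ --model.threshold 0.1 --ckpt_path best.ckpt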