-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 03afd4f
Showing
9 changed files
with
654 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
.idea | ||
*checkpoints | ||
/scratches | ||
/lightning_logs | ||
/Data | ||
/exports | ||
*/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# smartfashion | ||
Prediction of apparel product attributes based on photographs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from torchvision import models | ||
import torch | ||
import pytorch_lightning as pl | ||
from sklearn import metrics | ||
from argparse import ArgumentParser | ||
|
||
|
||
class AttrPred_Resnet50(pl.LightningModule): | ||
def __init__(self, n_attributes, prediction_threshold=.0, **kwargs): | ||
super(AttrPred_Resnet50, self).__init__() | ||
self.predictor = models.resnet50(pretrained=True) | ||
self.predictor.fc = torch.nn.Linear(in_features=2048, out_features=n_attributes) | ||
self.prediction_threshold = prediction_threshold | ||
self.save_hyperparameters() | ||
self.checkpoint = pl.callbacks.ModelCheckpoint(filename='{epoch:02d}-{avg_f1_score:.3}.ckpt', | ||
monitor="avg_f1_score", save_top_k=-1, mode="max") | ||
|
||
@staticmethod | ||
def add_model_specific_args(parent_parser): | ||
parser = ArgumentParser(parents=[parent_parser], add_help=False) | ||
parser.add_argument("--prediction_threshold", type=float, default=0., help="Threshold to " | ||
"trigger attribute prediction") | ||
return parser | ||
|
||
def forward(self, x, *args, **kwargs): | ||
return self.predictor(x) | ||
|
||
def training_step(self, batch, batch_id, *args, **kwargs): | ||
x, y = batch | ||
scores = self(x) | ||
loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, y, reduction="mean") | ||
self.log("train_loss", loss, on_step=True) | ||
self.log("avg_train_loss", loss, on_epoch=True) | ||
return loss | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return optimizer | ||
|
||
def validation_step(self, batch, batch_id, *args, **kwargs): | ||
x, y = batch | ||
scores = self(x) | ||
y_hat = (scores > self.prediction_threshold).float() | ||
val_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y, reduction="mean") | ||
f1_score = torch.tensor(metrics.f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro")) | ||
self.log("avg_f1_score", f1_score, prog_bar=True, on_epoch=True) | ||
self.log("avg_val_loss", val_loss, prog_bar=True, on_epoch=True) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
model = AttrPred_Resnet50(228, prediction_threshold=.0) | ||
trainer = pl.Trainer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from torchvision import models | ||
import torch | ||
import pytorch_lightning as pl | ||
from sklearn import metrics | ||
from argparse import ArgumentParser | ||
|
||
|
||
class DebugNet(pl.LightningModule): | ||
def __init__(self, n_attributes, prediction_threshold=.0, input_dim=(512, 512), **kwargs): | ||
""" | ||
A very small densely connected prediction network for local debugging of the training pipeline without having | ||
a gpu at hand. Not meant for actual inference! | ||
:param n_attributes: number of attributes to predict | ||
:param prediction_threshold: threshold score above which attribute is counted as predicted | ||
:param input_dim: | ||
:param kwargs: buffers hyperparameters for saving (see self.save_hyperparameters) | ||
""" | ||
super(DebugNet, self).__init__() | ||
self.input_dim = input_dim | ||
self.predictor = torch.nn.Linear(3*input_dim[0]*input_dim[1], n_attributes) | ||
self.prediction_threshold = prediction_threshold | ||
self.save_hyperparameters() | ||
|
||
@staticmethod | ||
def add_model_specific_args(parent_parser): | ||
parser = ArgumentParser(parents=[parent_parser], add_help=False) | ||
parser.add_argument("--prediction_threshold", type=float, default=0., help="Threshold to " | ||
"trigger attribute prediction") | ||
return parser | ||
|
||
def forward(self, x, *args, **kwargs): | ||
return self.predictor(x.view(-1, self.input_dim[0]*self.input_dim[1]*3)) | ||
|
||
def training_step(self, batch, batch_id, *args, **kwargs): | ||
x, y = batch | ||
scores = self(x) | ||
loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, y, reduction="mean") | ||
return loss | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) | ||
return optimizer | ||
|
||
def validation_step(self, batch, batch_id, *args, **kwargs): | ||
x, y = batch | ||
scores = self(x) | ||
y_hat = (scores > self.prediction_threshold).float() | ||
val_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y, reduction="mean") | ||
f1_score = metrics.f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro") | ||
|
||
return {"val_loss": val_loss, "f1_score": f1_score} | ||
|
||
|
||
if __name__ == "__main__": | ||
model = DebugNet(228, prediction_threshold=.0) | ||
trainer = pl.Trainer() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from AttrPredModel import AttrPred_Resnet50 | ||
import argparse | ||
import torch | ||
from torchsummary import summary | ||
import sys | ||
|
||
|
||
def export2sas(argstring=None): | ||
parser = argparse.ArgumentParser("Extracting a trained Attribute Prediction Model to ONNX format") | ||
parser.add_argument("-c", type=str, help="path to checkpoint file", required=True) | ||
parser.add_argument("-o", type=str, help="path to output file", required=True) | ||
|
||
if argstring: | ||
args = parser.parse_args(argstring) | ||
else: | ||
args = parser.parse_args() | ||
|
||
checkpoint_path = args.c | ||
output_path = args.o | ||
|
||
model = AttrPred_Resnet50.load_from_checkpoint(checkpoint_path) | ||
model.cpu() | ||
# print(summary(model, input_size=(3, 512, 512), device="cpu")) | ||
|
||
input_sample = torch.rand((1, 3, 512, 512)) | ||
# model.to_onnx(output_path, input_sample, export_params=True, input_names=["input"], output_names=["output"], | ||
# dynamic_axes={"input": {0: 'batch_size'}, "output": {0: "batch_size"}}) | ||
model.to_onnx(output_path, input_sample, export_params=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
if len(sys.argv) <= 1: | ||
source_file = "../lightning_logs/version_7064773/checkpoints/epoch=3.ckpt" | ||
outfile = "../exports/attr_pred_model.onnx" | ||
export2sas(f"-c {source_file} -o {outfile}") | ||
else: | ||
export2sas() |
Oops, something went wrong.