Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
malteprinzler committed Nov 15, 2020
0 parents commit 03afd4f
Show file tree
Hide file tree
Showing 9 changed files with 654 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.idea
*checkpoints
/scratches
/lightning_logs
/Data
/exports
*/__pycache__
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# smartfashion
Prediction of apparel product attributes based on photographs
53 changes: 53 additions & 0 deletions python_code/AttrPredModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from torchvision import models
import torch
import pytorch_lightning as pl
from sklearn import metrics
from argparse import ArgumentParser


class AttrPred_Resnet50(pl.LightningModule):
def __init__(self, n_attributes, prediction_threshold=.0, **kwargs):
super(AttrPred_Resnet50, self).__init__()
self.predictor = models.resnet50(pretrained=True)
self.predictor.fc = torch.nn.Linear(in_features=2048, out_features=n_attributes)
self.prediction_threshold = prediction_threshold
self.save_hyperparameters()
self.checkpoint = pl.callbacks.ModelCheckpoint(filename='{epoch:02d}-{avg_f1_score:.3}.ckpt',
monitor="avg_f1_score", save_top_k=-1, mode="max")

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--prediction_threshold", type=float, default=0., help="Threshold to "
"trigger attribute prediction")
return parser

def forward(self, x, *args, **kwargs):
return self.predictor(x)

def training_step(self, batch, batch_id, *args, **kwargs):
x, y = batch
scores = self(x)
loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, y, reduction="mean")
self.log("train_loss", loss, on_step=True)
self.log("avg_train_loss", loss, on_epoch=True)
return loss

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def validation_step(self, batch, batch_id, *args, **kwargs):
x, y = batch
scores = self(x)
y_hat = (scores > self.prediction_threshold).float()
val_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y, reduction="mean")
f1_score = torch.tensor(metrics.f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro"))
self.log("avg_f1_score", f1_score, prog_bar=True, on_epoch=True)
self.log("avg_val_loss", val_loss, prog_bar=True, on_epoch=True)



if __name__ == "__main__":
model = AttrPred_Resnet50(228, prediction_threshold=.0)
trainer = pl.Trainer()
56 changes: 56 additions & 0 deletions python_code/DebugNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from torchvision import models
import torch
import pytorch_lightning as pl
from sklearn import metrics
from argparse import ArgumentParser


class DebugNet(pl.LightningModule):
def __init__(self, n_attributes, prediction_threshold=.0, input_dim=(512, 512), **kwargs):
"""
A very small densely connected prediction network for local debugging of the training pipeline without having
a gpu at hand. Not meant for actual inference!
:param n_attributes: number of attributes to predict
:param prediction_threshold: threshold score above which attribute is counted as predicted
:param input_dim:
:param kwargs: buffers hyperparameters for saving (see self.save_hyperparameters)
"""
super(DebugNet, self).__init__()
self.input_dim = input_dim
self.predictor = torch.nn.Linear(3*input_dim[0]*input_dim[1], n_attributes)
self.prediction_threshold = prediction_threshold
self.save_hyperparameters()

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--prediction_threshold", type=float, default=0., help="Threshold to "
"trigger attribute prediction")
return parser

def forward(self, x, *args, **kwargs):
return self.predictor(x.view(-1, self.input_dim[0]*self.input_dim[1]*3))

def training_step(self, batch, batch_id, *args, **kwargs):
x, y = batch
scores = self(x)
loss = torch.nn.functional.binary_cross_entropy_with_logits(scores, y, reduction="mean")
return loss

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def validation_step(self, batch, batch_id, *args, **kwargs):
x, y = batch
scores = self(x)
y_hat = (scores > self.prediction_threshold).float()
val_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y, reduction="mean")
f1_score = metrics.f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro")

return {"val_loss": val_loss, "f1_score": f1_score}


if __name__ == "__main__":
model = DebugNet(228, prediction_threshold=.0)
trainer = pl.Trainer()
Empty file added python_code/__init__.py
Empty file.
37 changes: 37 additions & 0 deletions python_code/export_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from AttrPredModel import AttrPred_Resnet50
import argparse
import torch
from torchsummary import summary
import sys


def export2sas(argstring=None):
parser = argparse.ArgumentParser("Extracting a trained Attribute Prediction Model to ONNX format")
parser.add_argument("-c", type=str, help="path to checkpoint file", required=True)
parser.add_argument("-o", type=str, help="path to output file", required=True)

if argstring:
args = parser.parse_args(argstring)
else:
args = parser.parse_args()

checkpoint_path = args.c
output_path = args.o

model = AttrPred_Resnet50.load_from_checkpoint(checkpoint_path)
model.cpu()
# print(summary(model, input_size=(3, 512, 512), device="cpu"))

input_sample = torch.rand((1, 3, 512, 512))
# model.to_onnx(output_path, input_sample, export_params=True, input_names=["input"], output_names=["output"],
# dynamic_axes={"input": {0: 'batch_size'}, "output": {0: "batch_size"}})
model.to_onnx(output_path, input_sample, export_params=True)


if __name__ == "__main__":
if len(sys.argv) <= 1:
source_file = "../lightning_logs/version_7064773/checkpoints/epoch=3.ckpt"
outfile = "../exports/attr_pred_model.onnx"
export2sas(f"-c {source_file} -o {outfile}")
else:
export2sas()
Loading

0 comments on commit 03afd4f

Please sign in to comment.