diff --git a/project/lit_image_classifier.py b/project/lit_image_classifier.py index 1296a3f..e1f7005 100644 --- a/project/lit_image_classifier.py +++ b/project/lit_image_classifier.py @@ -25,7 +25,9 @@ def forward(self, x): class LitClassifier(pl.LightningModule): def __init__(self, backbone, learning_rate=1e-3): super().__init__() - self.save_hyperparameters() + # It's recomended to specify hyperparameters when using a backbone model. + # Specifically, avoid saving backbone model + self.save_hyperparameters(learning_rate) self.backbone = backbone def forward(self, x):