diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index 4c0c8a3775..fb18250ba0 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -407,8 +407,11 @@ def execute_heads(self) -> Dict[str, torch.Tensor]: @property def input_shape(self): - """If implemented, returns expected input tensor shape""" - raise NotImplementedError + """Returns the input shape that the model can accept, excluding the batch dimension. + + By default it returns (3, 224, 224). + """ + return (3, 224, 224) class _ClassyModelAdapter(ClassyModel): diff --git a/classy_vision/models/densenet.py b/classy_vision/models/densenet.py index c702cc081d..33867af32e 100644 --- a/classy_vision/models/densenet.py +++ b/classy_vision/models/densenet.py @@ -274,10 +274,3 @@ def forward(self, x): out = self.features(out) return out - - @property - def input_shape(self): - if self.small_input: - return (3, 32, 32) - else: - return (3, 224, 224) diff --git a/classy_vision/models/mlp.py b/classy_vision/models/mlp.py index 1552b96e40..1ca34b7184 100644 --- a/classy_vision/models/mlp.py +++ b/classy_vision/models/mlp.py @@ -85,7 +85,3 @@ def forward(self, x): out = x.view(batchsize_per_replica, -1) out = self.mlp(out) return out - - @property - def input_shape(self): - return (self._num_inputs,) diff --git a/classy_vision/models/regnet.py b/classy_vision/models/regnet.py index 81a34b1c39..51c4ac3f60 100644 --- a/classy_vision/models/regnet.py +++ b/classy_vision/models/regnet.py @@ -537,10 +537,6 @@ def init_weights(self): m.weight.data.normal_(mean=0.0, std=0.01) m.bias.data.zero_() - @property - def input_shape(self): - return (3, 224, 224) - # Register some "classic" RegNets class _RegNet(RegNet): diff --git a/classy_vision/models/resnext.py b/classy_vision/models/resnext.py index e361548488..1b2c838d44 100644 --- a/classy_vision/models/resnext.py +++ b/classy_vision/models/resnext.py @@ -434,13 +434,6 @@ def forward(self, x): return out - @property - def input_shape(self): - if self.small_input: - return (3, 32, 32) - else: - return (3, 224, 224) - def _convert_model_state(self, state): """Convert model state from the old implementation to the current format.