From 5be4b439de01434ef53b1d71a73e77fc358baa5d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 28 Nov 2024 01:31:48 +0100 Subject: [PATCH] fixed stage name --- modelconverter/utils/config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modelconverter/utils/config.py b/modelconverter/utils/config.py index 08a1bcc..f970cf6 100644 --- a/modelconverter/utils/config.py +++ b/modelconverter/utils/config.py @@ -447,9 +447,8 @@ def _validate_stages(cls, data: Dict[str, Any]) -> Dict[str, Any]: else: extra = {} for key in list(data.keys()): - if key not in cls.__fields__: + if key not in cls.model_fields: extra[key] = data.pop(key) - for stage in data["stages"].values(): for key, value in extra.items(): if key not in stage: @@ -459,7 +458,7 @@ def _validate_stages(cls, data: Dict[str, Any]) -> Dict[str, Any]: @model_validator(mode="after") def _validate_single_stage_name(self) -> Self: """Changes the default 'default_stage' name to the name of the input model.""" - if len(self.stages) == 1: + if len(self.stages) == 1 and "default_stage" in self.stages: stage = next(iter(self.stages.values())) model_name = stage.input_model.stem self.stages = {model_name: stage}