Skip to content

Commit

Permalink
Strip annotated when getting submodels during CLI parsing. (pydantic#490
Browse files Browse the repository at this point in the history
)
  • Loading branch information
kschwab authored Dec 2, 2024
1 parent a0924bc commit a903697
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,8 +1434,8 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
raise SettingsError(f'CliSubCommand is not outermost annotation for {model.__name__}.{field_name}')
elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False):
raise SettingsError(f'CliPositionalArg is not outermost annotation for {model.__name__}.{field_name}')
if is_model_class(type_) or is_pydantic_dataclass(type_):
sub_models.append(type_) # type: ignore
if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)):
sub_models.append(_strip_annotated(type_))
return sub_models

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
BaseModel,
ConfigDict,
DirectoryPath,
Discriminator,
Field,
Tag,
ValidationError,
)
from pydantic import (
Expand Down Expand Up @@ -2268,3 +2270,25 @@ class MySettings(BaseSettings):
CliApp.run(
MySettings, cli_args=['--bac', 'cli abbrev are invalid for internal parser'], cli_exit_on_error=False
)


def test_cli_submodels_strip_annotated():
class PolyA(BaseModel):
a: int = 1
type: Literal['a'] = 'a'

class PolyB(BaseModel):
b: str = '2'
type: Literal['b'] = 'b'

def _get_type(model: Union[BaseModel, Dict]) -> str:
if isinstance(model, dict):
return model.get('type', 'a')
return model.type # type: ignore

Poly = Annotated[Union[Annotated[PolyA, Tag('a')], Annotated[PolyB, Tag('b')]], Discriminator(_get_type)]

class WithUnion(BaseSettings):
poly: Poly

assert CliApp.run(WithUnion, ['--poly.type=a']).model_dump() == {'poly': {'a': 1, 'type': 'a'}}

0 comments on commit a903697

Please sign in to comment.