diff --git a/docs/index.md b/docs/index.md index e4e0d26..fd43bab 100644 --- a/docs/index.md +++ b/docs/index.md @@ -842,9 +842,10 @@ print(User().model_dump()) ### Subcommands and Positional Arguments -Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These -annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, -subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. +Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The +subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value). +Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses +`dataclass`. Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found. @@ -1284,6 +1285,9 @@ However, if your use case [aligns more with #2](#command-line-support), using Py likely want required fields to be _strictly required at the CLI_. We can enable this behavior by using `cli_enforce_required`. +!!! note + A required `CliPositionalArg` field is always strictly required (enforced) at the CLI. + ```py import os import sys diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 5e64164..584c7cd 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -1333,7 +1333,11 @@ def _load_env_vars( if subcommand_dest not in selected_subcommands: parsed_args[subcommand_dest] = self.cli_parse_none_str - parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')} + parsed_args = { + key: val + for key, val in parsed_args.items() + if not key.endswith(':subcommand') and val is not PydanticUndefined + } if selected_subcommands: last_selected_subcommand = max(selected_subcommands, key=len) if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name): @@ -1494,6 +1498,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, ) def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: + positional_variadic_arg = [] positional_args, subcommand_args, optional_args = [], [], [] for field_name, field_info in _get_model_fields(model).items(): if _CliSubCommand in field_info.metadata: @@ -1511,17 +1516,31 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] ) subcommand_args.append((field_name, field_info)) elif _CliPositionalArg in field_info.metadata: - if not field_info.is_required(): - raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value') + alias_names, *_ = _get_alias_names(field_name, field_info) + if len(alias_names) > 1: + raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') + is_append_action = _annotation_contains_types( + field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True + ) + if not is_append_action: + positional_args.append((field_name, field_info)) else: - alias_names, *_ = _get_alias_names(field_name, field_info) - if len(alias_names) > 1: - raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') - positional_args.append((field_name, field_info)) + positional_variadic_arg.append((field_name, field_info)) else: self._verify_cli_flag_annotations(model, field_name, field_info) optional_args.append((field_name, field_info)) - return positional_args + subcommand_args + optional_args + + if positional_variadic_arg: + if len(positional_variadic_arg) > 1: + field_names = ', '.join([name for name, info in positional_variadic_arg]) + raise SettingsError(f'{model.__name__} has multiple variadic positonal arguments: {field_names}') + elif subcommand_args: + field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args]) + raise SettingsError( + f'{model.__name__} has variadic positonal arguments and subcommand arguments: {field_names}' + ) + + return positional_args + positional_variadic_arg + subcommand_args + optional_args @property def root_parser(self) -> T: @@ -1727,11 +1746,9 @@ def _add_parser_args( self._cli_dict_args[kwargs['dest']] = field_info.annotation if _CliPositionalArg in field_info.metadata: - kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) - arg_names = [kwargs['dest']] - del kwargs['dest'] - del kwargs['required'] - flag_prefix = '' + arg_names, flag_prefix = self._convert_positional_arg( + kwargs, field_info, preferred_alias, model_default + ) self._convert_bool_flag(kwargs, field_info, model_default) @@ -1787,6 +1804,27 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}' ) + def _convert_positional_arg( + self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any + ) -> tuple[list[str], str]: + flag_prefix = '' + arg_names = [kwargs['dest']] + kwargs['default'] = PydanticUndefined + kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) + + # Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in + # conjunction with model_default instead of the derived kwargs['required']. + is_required = field_info.is_required() and model_default is PydanticUndefined + if kwargs.get('action') == 'append': + del kwargs['action'] + kwargs['nargs'] = '+' if is_required else '*' + elif not is_required: + kwargs['nargs'] = '?' + + del kwargs['dest'] + del kwargs['required'] + return arg_names, flag_prefix + def _get_arg_names( self, arg_prefix: str, diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 35bfcda..3c59016 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -1297,6 +1297,45 @@ class Cfg(BaseSettings): assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}} +def test_cli_optional_positional_arg(env): + class Main(BaseSettings): + model_config = SettingsConfigDict( + cli_parse_args=True, + cli_enforce_required=True, + ) + + value: CliPositionalArg[int] = 123 + + assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123} + + env.set('VALUE', '456') + assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456} + + assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789} + + +def test_cli_variadic_positional_arg(env): + class MainRequired(BaseSettings): + model_config = SettingsConfigDict(cli_parse_args=True) + + values: CliPositionalArg[List[int]] + + class MainOptional(MainRequired): + values: CliPositionalArg[List[int]] = [1, 2, 3] + + assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [1, 2, 3]} + with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'): + CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False) + + env.set('VALUES', '[4,5,6]') + assert CliApp.run(MainOptional, cli_args=[]).model_dump() == {'values': [4, 5, 6]} + with pytest.raises(SettingsError, match='error parsing CLI: the following arguments are required: VALUES'): + CliApp.run(MainRequired, cli_args=[], cli_exit_on_error=False) + + assert CliApp.run(MainOptional, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]} + assert CliApp.run(MainRequired, cli_args=['7', '8', '9']).model_dump() == {'values': [7, 8, 9]} + + def test_cli_enums(capsys, monkeypatch): class Pet(IntEnum): dog = 0 @@ -1416,13 +1455,26 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True): PositionalArgNotOutermost() with pytest.raises( - SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value' + SettingsError, + match='MultipleVariadicPositionialArgs has multiple variadic positonal arguments: strings, numbers', + ): + + class MultipleVariadicPositionialArgs(BaseSettings, cli_parse_args=True): + strings: CliPositionalArg[List[str]] + numbers: CliPositionalArg[List[int]] + + MultipleVariadicPositionialArgs() + + with pytest.raises( + SettingsError, + match='VariadicPositionialArgAndSubCommand has variadic positonal arguments and subcommand arguments: strings, sub_cmd', ): - class PositionalArgHasDefault(BaseSettings, cli_parse_args=True): - pos_arg: CliPositionalArg[str] = 'bad' + class VariadicPositionialArgAndSubCommand(BaseSettings, cli_parse_args=True): + strings: CliPositionalArg[List[str]] + sub_cmd: CliSubCommand[SubCmd] - PositionalArgHasDefault() + VariadicPositionialArgAndSubCommand() with pytest.raises( SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved ")