Skip to content

Commit

Permalink
Test + support custom constructor edge case (#259)
Browse files Browse the repository at this point in the history
* Add custom constructor test for tuple types

* Test + support adversarial custom constructor edge cases

* Cleanup

* Fix comment

* Python 3.8

* Need is_positional() check

* More backwards compatibility

* Note Python 3.10 edge case
  • Loading branch information
brentyi authored Feb 19, 2025
1 parent fd76d72 commit 9f991c7
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 21 deletions.
7 changes: 5 additions & 2 deletions src/tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ def get_value_from_arg(
# value, and the field default will be inspect.Parameter.empty.
if (
value in _fields.MISSING_AND_MISSING_NONPROP
and field.is_positional_call()
and arg.lowered.nargs in ("?", "*")
and arg.field.is_positional()
# nargs="?" is currently only used for optional positional
# arguments when the underlying nargs for the primitive
# constructor is 1. Logic for this is in _arguments.py.
and arg.lowered.nargs == "*"
):
value = []
should_cast = True
Expand Down
52 changes: 51 additions & 1 deletion tests/test_custom_constructors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import json
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import pytest
from helptext_utils import get_helptext_with_checks
from typing_extensions import Annotated, Literal, get_args

import tyro
Expand Down Expand Up @@ -136,3 +137,52 @@ def main(
with pytest.raises(SystemExit):
tyro.cli(main, args=["--field1"])
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]


def test_min_length_custom_constructor_positional() -> None:
def main(
field1: tyro.conf.Positional[ListOfStringsWithMinimumLength], field2: int = 3
) -> ListOfStringsWithMinimumLength:
del field2
return field1

with pytest.raises(SystemExit):
tyro.cli(main, args=[])
assert tyro.cli(main, args=["a", "b"]) == ["a", "b"]


TupleCustomConstructor = Annotated[
Tuple[str, ...],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="A TUPLE METAVAR",
is_instance=lambda x: isinstance(x, tuple)
and all(isinstance(i, str) for i in x),
instance_from_str=lambda args: tuple(args),
str_from_instance=lambda args: list(args),
),
]


def test_tuple_custom_constructors() -> None:
def main(field1: TupleCustomConstructor, field2: int = 3) -> tuple[str, ...]:
del field2
return field1

assert tyro.cli(main, args=["--field1", "a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["--field1", "a"]) == ("a",)
assert tyro.cli(main, args=["--field1"]) == ()
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)


def test_tuple_custom_constructors_positional() -> None:
def main(
field1: tyro.conf.Positional[TupleCustomConstructor], field2: int = 3
) -> tuple[str, ...]:
del field2
return field1

assert tyro.cli(main, args=["a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["a"]) == ("a",)
assert tyro.cli(main, args=[]) == ()
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)
121 changes: 113 additions & 8 deletions tests/test_positional_min_py38.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from __future__ import annotations

import sys
from typing import List, Optional, Tuple

import pytest
from helptext_utils import get_helptext_with_checks
from typing_extensions import Annotated

import tyro

Expand Down Expand Up @@ -31,11 +36,12 @@ def main(
assert tyro.cli(main, args="--x 1 --y 2 --z 3".split(" ")) == (1, 2, 3)


def test_nested_positional():
class A:
def __init__(self, a: int, hello_world: int, /, c: int):
self.hello_world = hello_world
class A:
def __init__(self, a: int, hello_world: int, /, c: int):
self.hello_world = hello_world


def test_nested_positional():
def nest1(a: int, b: int, thing: A, /, c: int) -> A:
return thing

Expand All @@ -45,11 +51,12 @@ def nest1(a: int, b: int, thing: A, /, c: int) -> A:
tyro.cli(nest1, args="0 1 2 3 4 4 --c 4".split(" "))


def test_nested_positional_alt():
class B:
def __init__(self, a: int, b: int, /, c: int):
pass
class B:
def __init__(self, a: int, b: int, /, c: int):
pass


def test_nested_positional_alt():
def nest2(a: int, b: int, /, thing: B, c: int):
return thing

Expand Down Expand Up @@ -116,3 +123,101 @@ def main(x: Tuple[int, int], y: Tuple[str, str], /):
return x, y

assert tyro.cli(main, args="1 2 3 4".split(" ")) == ((1, 2), ("3", "4"))


def make_list_of_strings_with_minimum_length(args: List[str]) -> List[str]:
if len(args) == 0:
raise ValueError("Expected at least one string")
return args


ListOfStringsWithMinimumLength = Annotated[
List[str],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="STR [STR ...]",
is_instance=lambda x: isinstance(x, list)
and all(isinstance(i, str) for i in x),
instance_from_str=make_list_of_strings_with_minimum_length,
str_from_instance=lambda args: args,
),
]


def test_min_length_custom_constructor_positional() -> None:
def main(
field1: ListOfStringsWithMinimumLength, /, field2: int = 3
) -> ListOfStringsWithMinimumLength:
del field2
return field1

with pytest.raises(SystemExit):
tyro.cli(main, args=[])
assert tyro.cli(main, args=["a", "b"]) == ["a", "b"]


TupleCustomConstructor = Annotated[
Tuple[str, ...],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="A TUPLE METAVAR",
is_instance=lambda x: isinstance(x, tuple)
and all(isinstance(i, str) for i in x),
instance_from_str=lambda args: tuple(args),
str_from_instance=lambda args: list(args),
),
]


def test_tuple_custom_constructors_positional() -> None:
def main(field1: TupleCustomConstructor, /, field2: int = 3) -> Tuple[str, ...]:
del field2
return field1

assert tyro.cli(main, args=["a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["a"]) == ("a",)
assert tyro.cli(main, args=[]) == ()
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)


TupleCustomConstructor2 = Annotated[
Tuple[str, ...],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="A TUPLE METAVAR",
is_instance=lambda x: isinstance(x, tuple)
and all(isinstance(i, str) for i in x),
instance_from_str=lambda args: tuple(args),
str_from_instance=lambda args: list(args),
),
]


if sys.version_info >= (3, 11):

def test_tuple_custom_constructors_positional_default_none() -> None:
# Waiting for typing_extensions with this fixed:
# https://github.com/python/typing_extensions/issues/310
def main(
field1: TupleCustomConstructor2 | None = None, /, field2: int = 3
) -> Tuple[str, ...] | None:
del field2
return field1

assert tyro.cli(main, args=["a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["a"]) == ("a",)
assert tyro.cli(main, args=[]) is None
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)


def test_tuple_custom_constructors_positional_default_five() -> None:
def main(
field1: TupleCustomConstructor2 | int = 5, /, field2: int = 3
) -> Tuple[str, ...] | int:
del field2
return field1

assert tyro.cli(main, args=["a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["a"]) == ("a",)
assert tyro.cli(main, args=[]) == 5
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)
52 changes: 51 additions & 1 deletion tests/test_py311_generated/test_custom_constructors_generated.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import json
from typing import Annotated, Any, Dict, List, Literal, get_args
from typing import Annotated, Any, Dict, List, Literal, Tuple, get_args

import numpy as np
import pytest
from helptext_utils import get_helptext_with_checks

import tyro

Expand Down Expand Up @@ -135,3 +136,52 @@ def main(
with pytest.raises(SystemExit):
tyro.cli(main, args=["--field1"])
assert tyro.cli(main, args=["--field1", "a", "b"]) == ["a", "b"]


def test_min_length_custom_constructor_positional() -> None:
def main(
field1: tyro.conf.Positional[ListOfStringsWithMinimumLength], field2: int = 3
) -> ListOfStringsWithMinimumLength:
del field2
return field1

with pytest.raises(SystemExit):
tyro.cli(main, args=[])
assert tyro.cli(main, args=["a", "b"]) == ["a", "b"]


TupleCustomConstructor = Annotated[
Tuple[str, ...],
tyro.constructors.PrimitiveConstructorSpec(
nargs="*",
metavar="A TUPLE METAVAR",
is_instance=lambda x: isinstance(x, tuple)
and all(isinstance(i, str) for i in x),
instance_from_str=lambda args: tuple(args),
str_from_instance=lambda args: list(args),
),
]


def test_tuple_custom_constructors() -> None:
def main(field1: TupleCustomConstructor, field2: int = 3) -> tuple[str, ...]:
del field2
return field1

assert tyro.cli(main, args=["--field1", "a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["--field1", "a"]) == ("a",)
assert tyro.cli(main, args=["--field1"]) == ()
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)


def test_tuple_custom_constructors_positional() -> None:
def main(
field1: tyro.conf.Positional[TupleCustomConstructor], field2: int = 3
) -> tuple[str, ...]:
del field2
return field1

assert tyro.cli(main, args=["a", "b"]) == ("a", "b")
assert tyro.cli(main, args=["a"]) == ("a",)
assert tyro.cli(main, args=[]) == ()
assert "A TUPLE METAVAR" in get_helptext_with_checks(main)
Loading

0 comments on commit 9f991c7

Please sign in to comment.