Skip to content

Commit

Permalink
Added first test that shows mentioned deficites
Browse files Browse the repository at this point in the history
  • Loading branch information
MrWhatZitToYaa committed Oct 14, 2024
1 parent 8ad3e29 commit 84053be
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,44 @@ def add_arguments_to_parser(self, parser):
assert cli.model.num_classes == 5


def test_lightning_cli_link_arguments_init():
# Will not work without init_args ("--data.init_args.batch_size=12")
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.init_args.batch_size")

cli_args = [
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
"--data.init_args.batch_size=12",
"--model.init_args.num_classes=5",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(run=False)

assert cli.datamodule.batch_size == 12

# Will work without init_args ("--data.batch_size=12")
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
pass

cli_args = [
"--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses",
"--model=tests_pytorch.test_cli.BoringModelRequiredClasses",
"--data.batch_size=12",
"--model.num_classes=12",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(run=False)

print(cli.config)

assert cli.datamodule.batch_size == 12


class EarlyExitTestModel(BoringModel):
def on_fit_start(self):
raise MisconfigurationException("Error on fit start")
Expand Down

0 comments on commit 84053be

Please sign in to comment.