From 84053beddaf5a87b4a3e02763ff505dc547bd348 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Mon, 14 Oct 2024 16:53:58 -0400 Subject: [PATCH] Added first test that shows mentioned deficites --- tests/tests_pytorch/test_cli.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 56b58d4d157a1..6fded08f332f5 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -597,6 +597,44 @@ def add_arguments_to_parser(self, parser): assert cli.model.num_classes == 5 +def test_lightning_cli_link_arguments_init(): + # Will not work without init_args ("--data.init_args.batch_size=12") + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.init_args.batch_size") + + cli_args = [ + "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses", + "--model=tests_pytorch.test_cli.BoringModelRequiredClasses", + "--data.init_args.batch_size=12", + "--model.init_args.num_classes=5", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(run=False) + + assert cli.datamodule.batch_size == 12 + + # Will work without init_args ("--data.batch_size=12") + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + pass + + cli_args = [ + "--data=tests_pytorch.test_cli.BoringDataModuleBatchSizeAndClasses", + "--model=tests_pytorch.test_cli.BoringModelRequiredClasses", + "--data.batch_size=12", + "--model.num_classes=12", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(run=False) + + print(cli.config) + + assert cli.datamodule.batch_size == 12 + + class EarlyExitTestModel(BoringModel): def on_fit_start(self): raise MisconfigurationException("Error on fit start")