diff --git a/mobly/suite_runner.py b/mobly/suite_runner.py index 3dd6c291..01f26b79 100644 --- a/mobly/suite_runner.py +++ b/mobly/suite_runner.py @@ -137,21 +137,63 @@ def _parse_cli_args(argv): return parser.parse_known_args(argv)[0] -def _find_suite_class(): - """Finds the test suite class in the current module. +def _find_suite_classes_in_module(module): + """Finds all test suite classes in the given module. - Walk through module members and find the subclass of BaseSuite. Only - one subclass is allowed in a module. + Walk through module members and find all classes that is a subclass of + BaseSuite. + + Args: + module: types.ModuleType, the module object to find test suite classes. Returns: - The test suite class in the test module. + A list of test suite classes. """ test_suites = [] - main_module_members = sys.modules['__main__'] - for _, module_member in main_module_members.__dict__.items(): + for _, module_member in module.__dict__.items(): if inspect.isclass(module_member): if issubclass(module_member, base_suite.BaseSuite): test_suites.append(module_member) + return test_suites + + +def _find_suite_class(): + """Finds the test suite class. + + First search for test suite classes in the __main__ module. If no test suite + class is found, search in the module that is calling + `suite_runner.run_suite_class`. + + Walk through module members and find the subclass of BaseSuite. Only + one subclass is allowed. + + Returns: + The test suite class in the test module. + """ + # Try to find test suites in __main__ module first. + test_suites = _find_suite_classes_in_module(sys.modules['__main__']) + + # Try to find test suites in the module of the caller of `run_suite_class`. + if len(test_suites) == 0: + logging.debug( + 'No suite class found in the __main__ module, trying to find it in the ' + 'module of the caller of suite_runner.run_suite_class method.' + ) + stacks = inspect.stack() + if len(stacks) < 2: + logging.debug( + 'Failed to get the caller stack of run_suite_class. Got stacks: %s', + stacks, + ) + else: + run_suite_class_caller_frame_info = inspect.stack()[2] + caller_frame = run_suite_class_caller_frame_info.frame + module = inspect.getmodule(caller_frame) + if module is None: + logging.debug('Failed to find module for frame %s', caller_frame) + else: + test_suites = _find_suite_classes_in_module(module) + if len(test_suites) != 1: logging.error( 'Expected 1 test class per file, found %s.', diff --git a/tests/lib/integration_test_suite.py b/tests/lib/integration_test_suite.py new file mode 100644 index 00000000..dd95ab04 --- /dev/null +++ b/tests/lib/integration_test_suite.py @@ -0,0 +1,31 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mobly import base_suite +from mobly import suite_runner +from tests.lib import integration_test + + +class IntegrationTestSuite(base_suite.BaseSuite): + + def setup_suite(self, config): + self.add_test_class(integration_test.IntegrationTest) + + +def main(): + suite_runner.run_suite_class() + + +if __name__ == "__main__": + main() diff --git a/tests/mobly/suite_runner_test.py b/tests/mobly/suite_runner_test.py index e4fad0a6..d380d23c 100755 --- a/tests/mobly/suite_runner_test.py +++ b/tests/mobly/suite_runner_test.py @@ -26,6 +26,7 @@ from mobly import test_runner from tests.lib import integration2_test from tests.lib import integration_test +from tests.lib import integration_test_suite class FakeTest1(base_test.BaseTestClass): @@ -140,8 +141,7 @@ def test_run_suite_with_failures(self, mock_exit): mock_exit.assert_called_once_with(1) @mock.patch('sys.exit') - @mock.patch.object(suite_runner, '_find_suite_class', autospec=True) - def test_run_suite_class(self, mock_find_suite_class, mock_exit): + def test_run_suite_class(self, mock_exit): tmp_file_path = self._gen_tmp_config_file() mock_cli_args = ['test_binary', f'--config={tmp_file_path}'] mock_called = mock.MagicMock() @@ -161,12 +161,14 @@ def teardown_suite(self): mock_called.teardown_suite() super().teardown_suite() - mock_find_suite_class.return_value = FakeTestSuite + sys.modules['__main__'].__dict__[FakeTestSuite.__name__] = FakeTestSuite with mock.patch.object(sys, 'argv', new=mock_cli_args): - suite_runner.run_suite_class() + try: + suite_runner.run_suite_class() + finally: + del sys.modules['__main__'].__dict__[FakeTestSuite.__name__] - mock_find_suite_class.assert_called_once() mock_called.setup_suite.assert_called_once_with() mock_called.teardown_suite.assert_called_once_with() mock_exit.assert_not_called() @@ -240,6 +242,24 @@ def setup_suite(self, config): {'FakeTest1': ['test_a']}, ) + @mock.patch('sys.exit') + @mock.patch.object(test_runner, 'TestRunner') + @mock.patch.object( + integration_test_suite.IntegrationTestSuite, 'setup_suite', autospec=True + ) + def test_run_suite_class_finds_suite_class_when_not_in_main_module( + self, mock_setup_suite, mock_test_runner_class, mock_exit + ): + mock_test_runner = mock_test_runner_class.return_value + mock_test_runner.results.is_all_pass = True + tmp_file_path = self._gen_tmp_config_file() + mock_cli_args = ['test_binary', f'--config={tmp_file_path}'] + + with mock.patch.object(sys, 'argv', new=mock_cli_args): + integration_test_suite.main() + + mock_setup_suite.assert_called_once() + def test_print_test_names(self): mock_test_class = mock.MagicMock() mock_cls_instance = mock.MagicMock()