From 598909f3d4384cc2b9afb7a3b17fd0ce39741ed9 Mon Sep 17 00:00:00 2001 From: Bret Naylor Date: Thu, 7 Dec 2023 09:24:24 -0500 Subject: [PATCH] fixing mpi tests --- .../test/test_multi_phase_restart.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/dymos/examples/finite_burn_orbit_raise/test/test_multi_phase_restart.py b/dymos/examples/finite_burn_orbit_raise/test/test_multi_phase_restart.py index ad55821bc..8657a0637 100644 --- a/dymos/examples/finite_burn_orbit_raise/test/test_multi_phase_restart.py +++ b/dymos/examples/finite_burn_orbit_raise/test/test_multi_phase_restart.py @@ -4,8 +4,7 @@ import openmdao.api as om from openmdao.utils.assert_utils import assert_near_equal, assert_warnings, assert_no_warning from openmdao.utils.testing_utils import use_tempdirs, require_pyoptsparse -from openmdao.utils.mpi import MPI -import scipy +from openmdao.utils.mpi import MPI, multi_proc_exception_check from dymos.examples.finite_burn_orbit_raise.finite_burn_orbit_raise_problem import two_burn_orbit_raise_problem from dymos.utils.testing_utils import assert_cases_equal @@ -25,9 +24,10 @@ def test_ex_two_burn_orbit_raise_connected(self): compressed=False, optimizer=optimizer, show_output=False, connected=True) - if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: - assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, - tolerance=4.0E-3) + with multi_proc_exception_check(p.comm): + if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: + assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, + tolerance=4.0E-3) case1 = om.CaseReader('dymos_solution.db').get_case('final') sim_case1 = om.CaseReader('dymos_simulation.db').get_case('final') @@ -54,9 +54,10 @@ def test_restart_from_solution_radau(self): case1 = om.CaseReader('dymos_solution.db').get_case('final') sim_case1 = om.CaseReader('dymos_simulation.db').get_case('final') - if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: - assert_near_equal(p.get_val('traj.burn2.states:deltav')[-1], 0.3995, - tolerance=2.0E-3) + with multi_proc_exception_check(p.comm): + if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: + assert_near_equal(p.get_val('traj.burn2.states:deltav')[-1], 0.3995, + tolerance=2.0E-3) # Run again without an actual optimzier two_burn_orbit_raise_problem(transcription='radau', transcription_order=3, @@ -101,9 +102,10 @@ def test_ex_two_burn_orbit_raise_connected(self): if (issubclass(warn.category, category) and str(warn.message) == msg): raise AssertionError(f"Saw unexpected warning {category.__name__}: {msg}") - if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: - assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, - tolerance=4.0E-3) + with multi_proc_exception_check(p.comm): + if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: + assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, + tolerance=4.0E-3) case1 = om.CaseReader('dymos_solution.db').get_case('final') sim_case1 = om.CaseReader('dymos_simulation.db').get_case('final') @@ -117,7 +119,7 @@ def test_ex_two_burn_orbit_raise_connected(self): case2 = om.CaseReader('dymos_solution2.db').get_case('final') sim_case2 = om.CaseReader('dymos_simulation2.db').get_case('final') -# + # Verify that the second case has the same inputs and outputs assert_cases_equal(case1, case2, tol=1.0E-8) assert_cases_equal(sim_case1, sim_case2, tol=1.0E-8) @@ -130,9 +132,10 @@ def test_restart_from_solution_radau_to_connected(self): compressed=False, optimizer=optimizer, show_output=False, connected=True) - if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: - assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, - tolerance=4.0E-3) + with multi_proc_exception_check(p.comm): + if p.model.traj.phases.burn2 in p.model.traj.phases._subsystems_myproc: + assert_near_equal(p.get_val('traj.burn2.states:deltav')[0], 0.3995, + tolerance=4.0E-3) case1 = om.CaseReader('dymos_solution.db').get_case('final') sim_case1 = om.CaseReader('dymos_simulation.db').get_case('final')