Skip to content

Commit

Permalink
fixes some issues when parameters serve as the rate source for RadauN…
Browse files Browse the repository at this point in the history
…ew and Birkhoff
  • Loading branch information
robfalck committed Jan 22, 2025
1 parent 9dd9f53 commit 13acb5b
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 127 deletions.
48 changes: 20 additions & 28 deletions dymos/examples/brachistochrone/doc/test_doc_brachistochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,11 @@ def test_brachistochrone_for_docs_gauss_lobatto(self):
#
# Set the initial values
#
p['traj.phase0.t_initial'] = 0.0
p['traj.phase0.t_duration'] = 2.0

p.set_val('traj.phase0.states:x', phase.interp('x', ys=[0, 10]))
p.set_val('traj.phase0.states:y', phase.interp('y', ys=[10, 5]))
p.set_val('traj.phase0.states:v', phase.interp('v', ys=[0, 9.9]))
p.set_val('traj.phase0.controls:theta', phase.interp('theta', ys=[5, 100.5]))
phase.set_time_val(initial=0, duration=2.0)
phase.set_state_val('x', [0, 10])
phase.set_state_val('y', [10, 5])
phase.set_state_val('v', [0, 9.9])
phase.set_control_val('theta', [5, 100.5], units='deg')

#
# Solve for the optimal trajectory
Expand Down Expand Up @@ -187,13 +185,11 @@ def test_brachistochrone_for_docs_radau(self):
#
# Set the initial values
#
p['traj.phase0.t_initial'] = 0.0
p['traj.phase0.t_duration'] = 2.0

p.set_val('traj.phase0.states:x', phase.interp('x', ys=[0, 10]))
p.set_val('traj.phase0.states:y', phase.interp('y', ys=[10, 5]))
p.set_val('traj.phase0.states:v', phase.interp('v', ys=[0, 9.9]))
p.set_val('traj.phase0.controls:theta', phase.interp('theta', ys=[5, 100.5]))
phase.set_time_val(initial=0, duration=2.0)
phase.set_state_val('x', [0, 10])
phase.set_state_val('y', [10, 5])
phase.set_state_val('v', [0, 9.9])
phase.set_control_val('theta', [5, 100.5], units='deg')

#
# Solve for the optimal trajectory
Expand Down Expand Up @@ -274,13 +270,11 @@ def test_brachistochrone_for_docs_coloring_demo(self):
#
# Set the initial values
#
p['traj.phase0.t_initial'] = 0.0
p['traj.phase0.t_duration'] = 2.0

p.set_val('traj.phase0.states:x', phase.interp('x', ys=[0, 10]))
p.set_val('traj.phase0.states:y', phase.interp('y', ys=[10, 5]))
p.set_val('traj.phase0.states:v', phase.interp('v', ys=[0, 9.9]))
p.set_val('traj.phase0.controls:theta', phase.interp('theta', ys=[5, 100.5]))
phase.set_time_val(initial=0, duration=2.0)
phase.set_state_val('x', [0, 10])
phase.set_state_val('y', [10, 5])
phase.set_state_val('v', [0, 9.9])
phase.set_control_val('theta', [5, 100.5], units='deg')

#
# Solve for the optimal trajectory
Expand Down Expand Up @@ -369,13 +363,11 @@ def test_brachistochrone_for_docs_coloring_demo_solve_segments(self):
#
# Set the initial values
#
p['traj.phase0.t_initial'] = 0.0
p['traj.phase0.t_duration'] = 2.0

p.set_val('traj.phase0.states:x', phase.interp('x', ys=[0, 10]))
p.set_val('traj.phase0.states:y', phase.interp('y', ys=[10, 5]))
p.set_val('traj.phase0.states:v', phase.interp('v', ys=[0, 9.9]))
p.set_val('traj.phase0.controls:theta', phase.interp('theta', ys=[5, 100.5]))
phase.set_time_val(initial=0, duration=2.0)
phase.set_state_val('x', [0, 10])
phase.set_state_val('y', [10, 5])
phase.set_state_val('v', [0, 9.9])
phase.set_control_val('theta', [5, 100.5], units='deg')

#
# Solve for the optimal trajectory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def test_integrate_polynomial_control_rate2_radau(self):
@use_tempdirs
class TestIntegrateTimeParamAndState(unittest.TestCase):

def _test_transcription(self, transcription=dm.GaussLobatto, time_name='time'):
def _test_transcription(self, transcription, time_name='time'):
#
# Define the OpenMDAO problem
#
Expand All @@ -758,8 +758,7 @@ def _test_transcription(self, transcription=dm.GaussLobatto, time_name='time'):
# Define a Dymos Phase object with GaussLobatto Transcription
#
phase = dm.Phase(ode_class=BrachistochroneODE,
transcription=transcription(num_segments=10, order=3,
solve_segments='forward'))
transcription=transcription)

traj.add_phase(name='phase0', phase=phase)

Expand Down Expand Up @@ -861,8 +860,11 @@ def _test_transcription(self, transcription=dm.GaussLobatto, time_name='time'):
assert_timeseries_near_equal(time_sol, int_int_one_sol, time_sim, int_int_one_sim, rel_tolerance=1.0E-12)

def test_integrated_times_params_and_states(self):
for tx in (dm.GaussLobatto, dm.Radau):
tx_name = 'GaussLobatto' if tx is dm.GaussLobatto else 'Radau'
txs = {'GaussLobatto': dm.GaussLobatto(num_segments=10, order=3,
solve_segments='forward'),
'Radau': dm.Radau(num_segments=10, order=3, solve_segments='forward'),
'Birkhoff': dm.Birkhoff(num_nodes=30, solve_segments='forward')}
for tx_name, tx in txs.items():
for time_name in ('time', 'elapsed_time'):
with self.subTest(msg=f'{tx_name}: time_name=\'{time_name}\''):
self._test_transcription(transcription=tx, time_name=time_name)
Expand Down
2 changes: 1 addition & 1 deletion dymos/examples/oscillator/oscillator_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def setup(self):
self.add_input('A', val=np.ones((nn, 2, 2)))

# Output
self.add_output('x_dot', val=np.zeros((nn, 2)))
self.add_output('x_dot', val=np.zeros((nn, 2)), units='1/s')

self.declare_partials(of='*', wrt='*', method='fd')

Expand Down
120 changes: 39 additions & 81 deletions dymos/examples/oscillator/test/test_oscillator_vector_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,44 @@ def test_matrix_param(self):

from dymos.examples.oscillator.oscillator_ode import OscillatorVectorODE

# Instantiate an OpenMDAO Problem instance.
prob = om.Problem()
prob.driver = om.ScipyOptimizeDriver()
prob.driver.options["optimizer"] = 'SLSQP'

static_params = False

t = dm.Radau(num_segments=2, order=3)
phase = dm.Phase(ode_class=OscillatorVectorODE, transcription=t,
ode_init_kwargs={'static_params': static_params})

phase.set_time_options(fix_initial=True, duration_bounds=(1, 2), duration_ref=1)
phase.add_state("x", fix_initial=True, rate_source="x_dot")

A_mat = np.array(
[
[0, 1],
[-1, 0]
]
)

# argument "dynamic" doesn't seem to help
phase.add_parameter("A", val=A_mat, targets=["A"], static_target=static_params)
phase.add_objective("time", loc="final", scaler=1)

traj = dm.Trajectory()
traj.add_phase("phase0", phase)

prob.model.add_subsystem("traj", traj)

prob.driver.declare_coloring()
prob.setup(force_alloc_complex=True)
phase.set_state_val('x', vals=[[1, 0], [1, 0]])

dm.run_problem(prob, run_driver=True, simulate=True, make_plots=True)
t_f = prob.get_val('traj.phase0.timeseries.time')[-1]
final_state = prob.get_val('traj.phase0.timeseries.x')[-1, :]
assert_near_equal(final_state, np.array([np.cos(t_f), -np.sin(t_f)]).ravel(),
tolerance=1e-5)

def test_matrix_static_param(self):

from dymos.examples.oscillator.oscillator_ode import OscillatorVectorODE
radau = dm.Radau(num_segments=2, order=3)
birkhoff = dm.Birkhoff(num_nodes=7)

# Instantiate an OpenMDAO Problem instance.
prob = om.Problem()
prob.driver = om.ScipyOptimizeDriver()
prob.driver.options["optimizer"] = 'SLSQP'

static_params = True

t = dm.Radau(num_segments=2, order=3)
phase = dm.Phase(ode_class=OscillatorVectorODE, transcription=t,
ode_init_kwargs={'static_params': static_params})

phase.set_time_options(fix_initial=True, duration_bounds=(1, 2), duration_ref=1)
phase.add_state("x", fix_initial=True, rate_source="x_dot")

A_mat = np.array(
[
[0, 1],
[-1, 0]
]
)

# argument "dynamic" doesn't seem to help
phase.add_parameter("A", val=A_mat, targets=["A"], static_target=static_params)
phase.add_objective("time", loc="final", scaler=1)

traj = dm.Trajectory()
traj.add_phase("phase0", phase)

prob.model.add_subsystem("traj", traj)

prob.driver.declare_coloring()
prob.setup(force_alloc_complex=True)
phase.set_state_val('x', vals=[[1, 0], [1, 0]])

dm.run_problem(prob, run_driver=True, simulate=True, make_plots=True)
t_f = prob.get_val('traj.phase0.timeseries.time')[-1]
final_state = prob.get_val('traj.phase0.timeseries.x')[-1, :]
assert_near_equal(final_state, np.array([np.cos(t_f), -np.sin(t_f)]).ravel(),
tolerance=1e-5)
for tx in (radau, birkhoff):
for static_params in (True, False):
with self.subTest(f'{static_params=}, {tx=}'):
prob = om.Problem()
prob.driver = om.ScipyOptimizeDriver()
prob.driver.options["optimizer"] = 'SLSQP'

phase = dm.Phase(ode_class=OscillatorVectorODE, transcription=tx,
ode_init_kwargs={'static_params': static_params})

phase.set_time_options(fix_initial=True, duration_bounds=(1, 2), duration_ref=1)
phase.add_state("x", fix_initial=True, rate_source="x_dot")

A_mat = np.array(
[
[0, 1],
[-1, 0]
]
)

phase.add_parameter("A", val=A_mat, targets=["A"], static_target=static_params)
phase.add_objective("time", loc="final", scaler=1)

traj = dm.Trajectory()
traj.add_phase("phase0", phase)

prob.model.add_subsystem("traj", traj)

prob.driver.declare_coloring()
prob.setup(force_alloc_complex=True)
phase.set_state_val('x', vals=[[1, 0], [1, 0]])

dm.run_problem(prob, run_driver=True, simulate=True, make_plots=True)
t_f = prob.get_val('traj.phase0.timeseries.time')[-1]
final_state = prob.get_val('traj.phase0.timeseries.x')[-1, :]
assert_near_equal(final_state, np.array([np.cos(t_f), -np.sin(t_f)]).ravel(),
tolerance=1e-5)
14 changes: 11 additions & 3 deletions dymos/transcriptions/pseudospectral/birkhoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,24 @@ def setup_defects(self, phase):

def configure_defects(self, phase):
"""
Configure the continuity_comp and connect the collocation constraints.
Connect the collocation constraints.
Parameters
----------
phase : dymos.Phase
The phase object to which this transcription instance applies.
"""
num_nodes = self.grid_data.subset_num_nodes['all']
for name, options in phase.state_options.items():
rate_source_type = phase.classify_var(options['rate_source'])
rate_src_path = self._get_rate_source_path(name, phase)
if rate_src_path.startswith('parameter_vals:'):
src_idxs = om.slicer[np.zeros(num_nodes, dtype=int), ...]
else:
src_idxs = None

if rate_source_type not in ('state', 'ode'):
phase.connect(rate_src_path, f'f_computed:{name}')
phase.connect(rate_src_path, f'f_computed:{name}', src_indices=src_idxs)

def setup_duration_balance(self, phase):
"""
Expand Down Expand Up @@ -781,9 +787,11 @@ def get_parameter_connections(self, name, phase):
else:
src_idxs_raw = np.zeros(self.grid_data.subset_num_nodes['all'], dtype=int)
src_idxs = get_src_indices_by_row(src_idxs_raw, options['shape'])
endpoint_src_idxs = om.slicer[[0, -1], ...]
endpoint_src_idxs_raw = np.zeros(2, dtype=int)
endpoint_src_idxs = get_src_indices_by_row(endpoint_src_idxs_raw, options['shape'])
if options['shape'] == (1,):
src_idxs = src_idxs.ravel()
endpoint_src_idxs = endpoint_src_idxs.ravel()

connection_info.append((f'ode_all.{tgt}', (src_idxs,)))
connection_info.append((f'boundary_vals.{tgt}', (endpoint_src_idxs,)))
Expand Down
18 changes: 9 additions & 9 deletions dymos/transcriptions/pseudospectral/radau_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,17 +288,15 @@ def configure_defects(self, phase):
grid_data = self.grid_data
col_idxs = grid_data.subset_node_indices['col']

for name, options in phase.state_options.items():
rate_source_type = phase.classify_var(options['rate_source'])
if rate_source_type == 'parameter':
for name in phase.state_options:
rate_src_path = self._get_rate_source_path(state_name=name,phase=phase)
if rate_src_path.startswith('parameter_vals:'):
src_idxs = om.slicer[np.zeros_like(col_idxs), ...]
else:
src_idxs = om.slicer[col_idxs, ...]
if rate_source_type not in ('state'):
# Do not need to connect state rates whose value comes from another state.
rate_src_path = self._get_rate_source_path(name, phase)
phase.connect(rate_src_path, f'f_ode:{name}',
src_indices=src_idxs)

if not rate_src_path.startswith('states:'):
phase.connect(rate_src_path, f'f_ode:{name}', src_indices=src_idxs)

def setup_duration_balance(self, phase):
"""
Expand Down Expand Up @@ -847,9 +845,11 @@ def get_parameter_connections(self, name, phase):
else:
src_idxs_raw = np.zeros(self.grid_data.subset_num_nodes['all'], dtype=int)
src_idxs = get_src_indices_by_row(src_idxs_raw, options['shape'])
endpoint_src_idxs = om.slicer[[0, -1], ...]
endpoint_src_idxs_raw = np.zeros(2, dtype=int)
endpoint_src_idxs = get_src_indices_by_row(endpoint_src_idxs_raw, options['shape'])
if options['shape'] == (1,):
src_idxs = src_idxs.ravel()
endpoint_src_idxs = endpoint_src_idxs.ravel()

connection_info.append((f'ode_all.{tgt}', (src_idxs,)))
connection_info.append((f'boundary_vals.{tgt}', (endpoint_src_idxs,)))
Expand Down

0 comments on commit 13acb5b

Please sign in to comment.