Skip to content

Commit

Permalink
coverage improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
robfalck committed Feb 6, 2025
1 parent 820bb8c commit b4f678c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 101 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest

import numpy as np
import openmdao.api as om
from openmdao.utils.assert_utils import assert_check_partials, assert_near_equal
import dymos as dm

from dymos.utils.testing_utils import SimpleODE
from dymos.utils.testing_utils import SimpleODE, SimpleVectorizedODE
from dymos.transcriptions.explicit_shooting.ode_evaluation_group import ODEEvaluationGroup


Expand Down Expand Up @@ -58,7 +59,61 @@ def test_eval(self):

assert_near_equal(p.get_val('ode_eval.state_rate_collector.state_rates:x_rate'), xdot_check)

cpd = p.check_partials(compact_print=True, method='cs')
cpd = p.check_partials(compact_print=True, method='cs', out_stream=None)
assert_check_partials(cpd)

def test_eval_vectorized(self):
ode_class = SimpleVectorizedODE
time_options = dm.phase.options.TimeOptionsDictionary()

time_options['targets'] = 't'
time_options['units'] = 's'

state_options = {'z': dm.phase.options.StateOptionsDictionary()}

state_options['z']['shape'] = (2,)
state_options['z']['units'] = 's**2'
state_options['z']['rate_source'] = 'z_dot'
state_options['z']['targets'] = ['z']

param_options = {'p': dm.phase.options.ParameterOptionsDictionary()}

param_options['p']['shape'] = (1,)
param_options['p']['units'] = 's**2'
param_options['p']['targets'] = ['p']

control_options = {}

p = om.Problem()

igd = dm.GaussLobattoGrid(num_segments=1, nodes_per_seg=3, compressed=False)

p.model.add_subsystem('ode_eval', ODEEvaluationGroup(ode_class,
input_grid_data=igd,
time_options=time_options,
state_options=state_options,
parameter_options=param_options,
control_options=control_options,
ode_init_kwargs=None))
p.setup(check=False, force_alloc_complex=True)

p.model.ode_eval.set_segment_index(0)
p.set_val('ode_eval.states:z', [1.25, 0.0])
p.set_val('ode_eval.time', [2.2])
p.set_val('ode_eval.parameters:p', [1.0])

p.run_model()

z = p.get_val('ode_eval.states:z')
p_ = p.get_val('ode_eval.parameters:p')
t = p.get_val('ode_eval.time')

z_rate = p.get_val('ode_eval.state_rate_collector.state_rates:z_rate')

assert_near_equal(z_rate[0, 0], z[:, 0] - t**2 + p_)
assert_near_equal(z_rate[0, 1], 10 * t)

cpd = p.check_partials(compact_print=True, method='cs', out_stream=None)
assert_check_partials(cpd)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,102 +239,3 @@ def configure_io(self, phase):

if var_type == 'ode':
self.connect(f'ode_all.{rate_source}', f'f_computed:{name}')

def _get_rate_source_path(self, state_name, nodes, phase):
"""
Return the rate source location and indices for a given state name.
Parameters
----------
state_name : str
Name of the state.
nodes : str
One of ['col', 'all'].
phase : dymos.Phase
Phase object containing the rate source.
Returns
-------
str
Path to the rate source.
ndarray
Array of source indices.
"""
gd = self.grid_data
try:
var = phase.state_options[state_name]['rate_source']
except RuntimeError:
raise ValueError(f"state '{state_name}' in phase '{phase.name}' was not given a "
"rate_source")

# Note the rate source must be shape-compatible with the state
var_type = phase.classify_var(var)

# Determine the path to the variable
if var_type == 't':
rate_path = 't'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 't_phase':
rate_path = 't_phase'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'state':
rate_path = f'states:{var}'
# Find the state_input indices which occur at segment endpoints, and repeat them twice
state_input_idxs = gd.subset_node_indices['state_input']
repeat_idxs = np.ones_like(state_input_idxs)
if self.options['compressed']:
segment_end_idxs = gd.subset_node_indices['segment_ends'][1:-1]
# Repeat nodes that are on segment bounds (but not the first or last nodes in the phase)
nodes_to_repeat = list(set(state_input_idxs).intersection(segment_end_idxs))
# Now find these nodes in the state input indices
idxs_of_ntr_in_state_inputs = np.where(np.isin(state_input_idxs, nodes_to_repeat))[0]
# All state input nodes are used once, but nodes_to_repeat are used twice
repeat_idxs[idxs_of_ntr_in_state_inputs] = 2
# Now we have a way of mapping the state input indices to all nodes
map_input_node_idxs_to_all = np.repeat(np.arange(gd.subset_num_nodes['state_input'],
dtype=int), repeats=repeat_idxs)
# Now select the subset of nodes we want to use.
node_idxs = map_input_node_idxs_to_all[gd.subset_node_indices[nodes]]
elif var_type == 'indep_control':
rate_path = f'control_values:{var}'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'input_control':
rate_path = f'control_values:{var}'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'control_rate':
control_name = var[:-5]
rate_path = f'control_rates:{control_name}_rate'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'control_rate2':
control_name = var[:-6]
rate_path = f'control_rates:{control_name}_rate2'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'indep_polynomial_control':
rate_path = f'control_values:{var}'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'input_polynomial_control':
rate_path = f'control_values:{var}'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'polynomial_control_rate':
control_name = var[:-5]
rate_path = f'control_rates:{control_name}_rate'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'polynomial_control_rate2':
control_name = var[:-6]
rate_path = f'control_rates:{control_name}_rate2'
node_idxs = gd.subset_node_indices[nodes]
elif var_type == 'parameter':
rate_path = f'parameter_vals:{var}'
dynamic = not phase.parameter_options[var]['static_target']
if dynamic:
node_idxs = np.zeros(gd.subset_num_nodes[nodes], dtype=int)
else:
node_idxs = np.zeros(1, dtype=int)
else:
# Failed to find variable, assume it is in the ODE
rate_path = f'ode_all.{var}'
node_idxs = gd.subset_node_indices[nodes]

src_idxs = om.slicer[node_idxs, ...]

return rate_path, src_idxs

0 comments on commit b4f678c

Please sign in to comment.