Skip to content

Commit

Permalink
update to testing utils
Browse files Browse the repository at this point in the history
  • Loading branch information
robfalck committed Feb 4, 2025
1 parent 2f2c268 commit 1139da4
Showing 1 changed file with 131 additions and 37 deletions.
168 changes: 131 additions & 37 deletions dymos/utils/testing_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import os

from packaging.version import Version

Expand Down Expand Up @@ -361,52 +360,147 @@ def _get_reports_dir(prob):
return get_reports_dir(prob)


# This duplicates OpenMDAO code and is needed for older versions of OpenMDAO (<= 3.19).
# Once support is dropped for < 3.19 we can get rid of this and use the version from OpenMDAO.
class set_env_vars(object):
class PhaseStub():
"""
Decorate a function to temporarily set some environment variables.
A stand-in for the Phase during config_io for testing.
Parameters
----------
**envs : dict
Keyword args corresponding to environment variables to set.
It just supports the classify_var method and returns "ode", the only value needed for unittests.
"""
def __init__(self):
self.nonlinear_solver = None
self.linear_solver = None

Attributes
----------
envs : dict
Saved mapping of environment var name to value.
def classify_var(self, name):
"""
A stand-in for classify_var that always sets the variable type to name.
Parameters
----------
name : str
The name of the variable to classify.
Returns
-------
str
The variable classification.
"""
return 'ode'

class SimpleODE(om.ExplicitComponent):
"""
A simple ODE from https://math.okstate.edu/people/yqwang/teaching/math4513_fall11/Notes/rungekutta.pdf
"""
def initialize(self):
"""
Declare options for SimpleODE.
"""
self.options.declare('num_nodes', types=(int,))

def setup(self):
"""
Add inputs and outputs to SimpleODE.
"""
nn = self.options['num_nodes']
self.add_input('x', shape=(nn,), units='s**2')
self.add_input('t', shape=(nn,), units='s')
self.add_input('p', shape=(nn,), units='s**2')

self.add_output('x_dot', shape=(nn,), units='s')

ar = np.arange(nn, dtype=int)
self.declare_partials(of='x_dot', wrt='x', rows=ar, cols=ar, val=1.0)
self.declare_partials(of='x_dot', wrt='t', rows=ar, cols=ar)
self.declare_partials(of='x_dot', wrt='p', rows=ar, cols=ar, val=1.0)

def compute(self, inputs, outputs):
"""
Compute the outputs of SimpleVectorizedODE.
Parameters
----------
inputs : Vector
Vector of inputs.
outputs : Vector
Vector of outputs.
"""
x = inputs['x']
t = inputs['t']
p = inputs['p']
outputs['x_dot'] = x - t**2 + p

def compute_partials(self, inputs, partials):
"""
Compute the partials of SimpleVectorizedODE.
Parameters
----------
inputs : Vector
Vector of inputs.
partials : Dictionary
Vector of partials.
"""
t = inputs['t']
partials['x_dot', 't'] = -2*t


class SimpleVectorizedODE(om.ExplicitComponent):
"""
A simple ODE from https://math.okstate.edu/people/yqwang/teaching/math4513_fall11/Notes/rungekutta.pdf
"""
def initialize(self):
"""
Declare options for SimpleVectorizedODE.
"""
self.options.declare('num_nodes', types=(int,))

def __init__(self, **envs):
def setup(self):
"""
Initialize attributes.
Add inputs and outputs to SimpleVectorizedODE.
"""
nn = self.options['num_nodes']
self.add_input('z', shape=(nn, 2), units='s**2')
self.add_input('t', shape=(nn,), units='s')
self.add_input('p', shape=(nn,), units='s**2')

self.add_output('z_dot', shape=(nn, 2), units='s')

cs = np.repeat(np.arange(nn, dtype=int), 2)
ar2 = np.arange(2 * nn, dtype=int)
dzdot_dz_pattern = np.arange(2 * nn, step=2, dtype=int)
self.declare_partials(of='z_dot', wrt='z', rows=dzdot_dz_pattern, cols=dzdot_dz_pattern, val=1.0)
self.declare_partials(of='z_dot', wrt='t', rows=ar2, cols=cs)
dzdot_dp_rows = np.arange(2 * nn, step=2, dtype=int)
dzdot_dp_cols = np.arange(nn, dtype=int)
self.declare_partials(of='z_dot', wrt='p', rows=dzdot_dp_rows, cols=dzdot_dp_cols, val=1.0)

def compute(self, inputs, outputs):
"""
Compute the outputs of SimpleVectorizedODE.
Parameters
----------
inputs : Vector
Vector of inputs.
outputs : Vector
Vector of outputs.
"""
self.envs = envs
z = inputs['z']
t = inputs['t']
p = inputs['p']
outputs['z_dot'][:, 0] = z[:, 0] - t**2 + p
outputs['z_dot'][:, 1] = 10 * t

def __call__(self, fnc):
def compute_partials(self, inputs, partials):
"""
Apply the decorator.
Compute the partials of SimpleVectorizedODE.
Parameters
----------
fnc : function
The function being wrapped.
inputs : Vector
Vector of inputs.
partials : Dictionary
Vector of partials.
"""
def wrap(*args, **kwargs):
saved = {}
try:
for k, v in self.envs.items():
saved[k] = os.environ.get(k)
os.environ[k] = v # will raise exception if v is not a string

return fnc(*args, **kwargs)
finally:
# put environment back as it was
for k, v in saved.items():
if v is None:
del os.environ[k]
else:
os.environ[k] = v

return wrap
t = inputs['t']
partials['z_dot', 't'][0::2] = -2*t
partials['z_dot', 't'][1::2] = 10

0 comments on commit 1139da4

Please sign in to comment.