Skip to content

Commit

Permalink
finish test
Browse files Browse the repository at this point in the history
  • Loading branch information
surgura committed Jun 5, 2024
1 parent e24c809 commit a3294f7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
15 changes: 11 additions & 4 deletions tests/instruments/test_ship_underwater_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from virtual_ship.instruments.ship_underwater_st import simulate_ship_underwater_st
import xarray as xr
from typing import Callable
from datetime import timedelta


def test_simulate_ship_underwater_st(tmp_dir_factory: Callable[[str], str]) -> None:
Expand All @@ -21,8 +22,14 @@ def test_simulate_ship_underwater_st(tmp_dir_factory: Callable[[str], str]) -> N
Spacetime(Location(7, 0), base_time + np.timedelta64(1, "s")),
]
expected_obs = [
{"salinity": 1, "temperature": 2, "lat": 3, "lon": 0},
{"salinity": 5, "temperature": 6, "lat": 7, "lon": 0},
{"salinity": 1, "temperature": 2, "lat": 3, "lon": 0, "time": base_time},
{
"salinity": 5,
"temperature": 6,
"lat": 7,
"lon": 0,
"time": base_time + np.timedelta64(1, "s"),
},
]

fieldset = FieldSet.from_data(
Expand All @@ -41,7 +48,7 @@ def test_simulate_ship_underwater_st(tmp_dir_factory: Callable[[str], str]) -> N
{
"lon": 0,
"lat": np.array([expected_obs[0]["lat"], expected_obs[1]["lat"]]),
"time": np.array([base_time, base_time + np.timedelta64(1, "s")]),
"time": np.array([expected_obs[0]["time"], expected_obs[1]["time"]]),
},
)

Expand All @@ -57,7 +64,7 @@ def test_simulate_ship_underwater_st(tmp_dir_factory: Callable[[str], str]) -> N
results = xr.open_zarr(out_file_name)

assert len(results.trajectory) == 1
assert len(results.sel(trajectory=0).obs == len(sample_points))
assert len(results.sel(trajectory=0).obs) == len(sample_points)

for i, (obs_i, exp) in enumerate(
zip(results.sel(trajectory=0).obs, expected_obs, strict=True)
Expand Down
22 changes: 15 additions & 7 deletions virtual_ship/instruments/ship_underwater_st.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Ship salinity and temperature."""

import numpy as np
from parcels import FieldSet, JITParticle, ParticleSet, Variable
from parcels import FieldSet, ScipyParticle, ParticleSet, Variable

from ..spacetime import Spacetime

_ShipSTParticle = JITParticle.add_variables(
# we specifically use ScipyParticle because we have many small calls to execute
# JITParticle would require compilation every time
# this ends up being faster
_ShipSTParticle = ScipyParticle.add_variables(
[
Variable("salinity", dtype=np.float32, initial=np.nan),
Variable("temperature", dtype=np.float32, initial=np.nan),
Expand Down Expand Up @@ -40,32 +43,37 @@ def simulate_ship_underwater_st(
:param out_file_name: The file to write the results to.
:param depth: The depth at which to measure. 0 is water surface, negative is into the water.
:param sample_points: The places and times to sample at.
:param sample_dt: Time between each sample point.
:param output_dt: Period of writing to output file.
"""
sample_points.sort(key=lambda p: p.time)

particleset = ParticleSet.from_list(
fieldset=fieldset,
pclass=_ShipSTParticle,
lon=0.0, # initial lat/lon are irrelevant and will be overruled later.
lon=0.0, # initial lat/lon are irrelevant and will be overruled later
lat=0.0,
depth=depth,
time=0, # same for time
)

# define output file for the simulation
out_file = particleset.ParticleFile(
name=out_file_name,
)
# the default outputdt is good(infinite), as we want to just want to write at the end of every call to 'execute'
out_file = particleset.ParticleFile(name=out_file_name)

# iterate over each points, manually set lat lon time, then
# execute the particle set for one step, performing one set of measurement
for point in sample_points:
particleset.lon_nextloop[:] = point.location.lon
particleset.lat_nextloop[:] = point.location.lat
particleset.time_nextloop[:] = fieldset.time_origin.reltime(point.time)

# perform one step using the particleset
# dt and runtime are set so exactly one step is made.
particleset.execute(
[_sample_salinity, _sample_temperature],
dt=1,
runtime=1,
verbose_progress=False,
output_file=out_file,
)
out_file.write(particleset, time=particleset[0].time)

0 comments on commit a3294f7

Please sign in to comment.