Skip to content

Commit

Permalink
Restore noise simulation unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tskisner committed Nov 17, 2020
1 parent d2b9f4d commit 5874e56
Show file tree
Hide file tree
Showing 7 changed files with 396 additions and 445 deletions.
4 changes: 4 additions & 0 deletions src/toast/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __repr__(self):
val += "\n>"
return val

def __del__(self):
if hasattr(self, "obs"):
self.clear()

@property
def comm(self):
"""The toast.Comm over which the data is distributed."""
Expand Down
53 changes: 30 additions & 23 deletions src/toast/future_ops/sim_tod_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

import numpy as np

from scipy import interpolate

from .. import rng

from ..timing import function_timer

from ..traits import trait_docs, Int, Unicode
Expand All @@ -21,17 +25,17 @@

@function_timer
def sim_noise_timestream(
realization,
telescope,
component,
obsindx,
detindx,
rate,
firstsamp,
samples,
oversample,
freq,
psd,
realization=0,
telescope=0,
component=0,
obsindx=0,
detindx=0,
rate=1.0,
firstsamp=0,
samples=0,
oversample=2,
freq=None,
psd=None,
py=False,
):
"""Generate a noise timestream, given a starting RNG state.
Expand Down Expand Up @@ -124,7 +128,9 @@ def sim_noise_timestream(
logfreq = np.log10(freq + freqshift)
logpsd = np.log10(psd + psdshift)

interp = si.interp1d(logfreq, logpsd, kind="linear", fill_value="extrapolate")
interp = interpolate.interp1d(
logfreq, logpsd, kind="linear", fill_value="extrapolate"
)

loginterp_psd = interp(loginterp_freq)
interp_psd = np.power(10.0, loginterp_psd) - psdshift
Expand Down Expand Up @@ -292,17 +298,18 @@ def _exec(self, data, detectors=None, **kwargs):

# Simulate the noise matching this key
nsedata = sim_noise_timestream(
self.realization,
telescope,
self.component,
obsindx,
nse.index(key),
rate,
ob.local_index_offset + global_offset,
ob.n_local_samples,
self._oversample,
nse.freq(key),
nse.psd(key),
realization=self.realization,
telescope=telescope,
component=self.component,
obsindx=obsindx,
detindx=nse.index(key),
rate=rate,
firstsamp=ob.local_index_offset + global_offset,
samples=ob.n_local_samples,
oversample=self._oversample,
freq=nse.freq(key),
psd=nse.psd(key),
py=False,
)

# Add the noise to all detectors that have nonzero weights
Expand Down
10 changes: 6 additions & 4 deletions src/toast/observation_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ def __getitem__(self, key):
return self._internal[key]

def __delitem__(self, key):
self._internal[key].clear()
del self._internal[key]
if key in self._internal:
self._internal[key].clear()
del self._internal[key]

def __setitem__(self, key, value):
if isinstance(value, DetectorData):
Expand Down Expand Up @@ -626,8 +627,9 @@ def __getitem__(self, key):
return self._internal[key]

def __delitem__(self, key):
self._internal[key].close()
del self._internal[key]
if key in self._internal:
self._internal[key].close()
del self._internal[key]

def __setitem__(self, key, value):
if isinstance(value, MPIShared):
Expand Down
9 changes: 7 additions & 2 deletions src/toast/tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ def create_telescope(group_size, sample_rate=10.0 * u.Hz):
while 2 * npix < group_size:
npix += 6 * ring
ring += 1
fp = fake_hexagon_focalplane(n_pix=npix)
fp = fake_hexagon_focalplane(
n_pix=npix,
sample_rate=sample_rate,
f_min=1.0e-5 * u.Hz,
f_knee=(sample_rate / 2000.0),
)
return Telescope("test", focalplane=fp)


Expand Down Expand Up @@ -147,7 +152,7 @@ def create_satellite_data(

sim_sat = ops.SimSatellite(
name="sim_sat",
n_observation=(toastcomm.ngroups * obs_per_group),
num_observations=(toastcomm.ngroups * obs_per_group),
telescope=tele,
hwp_rpm=10.0,
observation_time=obs_time,
Expand Down
58 changes: 27 additions & 31 deletions src/toast/tests/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,46 @@
# All rights reserved. Use of this source code is governed by
# a BSD-style license that can be found in the LICENSE file.

from ..mpi import MPI, use_mpi

import sys
import time

import warnings

from unittest.signals import registerResult

from unittest import TestCase
from unittest import TestResult


class MPITestCase(TestCase):
"""A simple wrapper around the standard TestCase which provides
one extra method to set the communicator.
"""
"""A simple wrapper around the standard TestCase which stores the communicator."""

def __init__(self, *args, **kwargs):
super(MPITestCase, self).__init__(*args, **kwargs)

def setComm(self, comm):
self.comm = comm
super().__init__(*args, **kwargs)
self.comm = None
if use_mpi:
self.comm = MPI.COMM_WORLD


class MPITestResult(TestResult):
"""A test result class that can print formatted text results to a stream.
The actions needed are coordinated across all processes.
Used by MPITestRunner.
"""

separator1 = "=" * 70
separator2 = "-" * 70

def __init__(self, comm, stream=None, descriptions=None, verbosity=None, **kwargs):
super(MPITestResult, self).__init__(
def __init__(self, stream=None, descriptions=None, verbosity=None, **kwargs):
super().__init__(
stream=stream, descriptions=descriptions, verbosity=verbosity, **kwargs
)
self.comm = comm
self.comm = None
if use_mpi:
self.comm = MPI.COMM_WORLD
self.stream = stream
self.descriptions = descriptions
self.buffer = False
Expand All @@ -53,8 +55,7 @@ def getDescription(self, test):
return str(test)

def startTest(self, test):
if isinstance(test, MPITestCase):
test.setComm(self.comm)
super().startTest(test)
self.stream.flush()
if self.comm is not None:
self.comm.barrier()
Expand All @@ -65,11 +66,10 @@ def startTest(self, test):
self.stream.flush()
if self.comm is not None:
self.comm.barrier()
super(MPITestResult, self).startTest(test)
return

def addSuccess(self, test):
super(MPITestResult, self).addSuccess(test)
super().addSuccess(test)
if self.comm is None:
self.stream.write("ok ")
else:
Expand All @@ -78,7 +78,7 @@ def addSuccess(self, test):
return

def addError(self, test, err):
super(MPITestResult, self).addError(test, err)
super().addError(test, err)
if self.comm is None:
self.stream.write("error ")
else:
Expand All @@ -87,7 +87,7 @@ def addError(self, test, err):
return

def addFailure(self, test, err):
super(MPITestResult, self).addFailure(test, err)
super().addFailure(test, err)
if self.comm is None:
self.stream.write("fail ")
else:
Expand All @@ -96,7 +96,7 @@ def addFailure(self, test, err):
return

def addSkip(self, test, reason):
super(MPITestResult, self).addSkip(test, reason)
super().addSkip(test, reason)
if self.comm is None:
self.stream.write("skipped({}) ".format(reason))
else:
Expand All @@ -105,7 +105,7 @@ def addSkip(self, test, reason):
return

def addExpectedFailure(self, test, err):
super(MPITestResult, self).addExpectedFailure(test, err)
super().addExpectedFailure(test, err)
if self.comm is None:
self.stream.write("expected-fail ")
else:
Expand All @@ -114,11 +114,11 @@ def addExpectedFailure(self, test, err):
return

def addUnexpectedSuccess(self, test):
super(MPITestResult, self).addUnexpectedSuccess(test)
super().addUnexpectedSuccess(test)
if self.comm is None:
self.stream.writeln("unexpected-success ")
self.stream.write("unexpected-success ")
else:
self.stream.writeln("[{}]unexpected-success ".format(self.comm.rank))
self.stream.write("[{}]unexpected-success ".format(self.comm.rank))
return

def printErrorList(self, flavour, errors):
Expand All @@ -142,15 +142,13 @@ def printErrorList(self, flavour, errors):
def printErrors(self):
if self.comm is None:
self.stream.writeln()
self.stream.flush()
self.printErrorList("ERROR", self.errors)
self.printErrorList("FAIL", self.failures)
self.stream.flush()
else:
self.comm.barrier()
if self.comm.rank == 0:
self.stream.writeln()
self.stream.flush()
for p in range(self.comm.size):
if p == self.comm.rank:
self.printErrorList("ERROR", self.errors)
Expand Down Expand Up @@ -203,15 +201,15 @@ class MPITestRunner(object):

resultclass = MPITestResult

def __init__(
self, comm, stream=None, descriptions=True, verbosity=2, warnings=None
):
def __init__(self, stream=None, descriptions=True, verbosity=2, warnings=None):
"""Construct a MPITestRunner.
Subclasses should accept **kwargs to ensure compatibility as the
interface changes.
"""
self.comm = comm
self.comm = None
if use_mpi:
self.comm = MPI.COMM_WORLD
if stream is None:
stream = sys.stderr
self.stream = _WritelnDecorator(stream)
Expand All @@ -221,9 +219,7 @@ def __init__(

def run(self, test):
"Run the given test case or test suite."
result = MPITestResult(
self.comm, self.stream, self.descriptions, self.verbosity
)
result = MPITestResult(self.stream, self.descriptions, self.verbosity)
registerResult(result)
with warnings.catch_warnings():
if self.warnings:
Expand Down
Loading

0 comments on commit 5874e56

Please sign in to comment.