Skip to content

Commit 63a50b0

Browse files
committed
Restore noise simulation unit tests.
1 parent 2b68ca8 commit 63a50b0

File tree

7 files changed

+396
-445
lines changed

7 files changed

+396
-445
lines changed

src/toast/data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def __repr__(self):
5151
val += "\n>"
5252
return val
5353

54+
def __del__(self):
55+
if hasattr(self, "obs"):
56+
self.clear()
57+
5458
@property
5559
def comm(self):
5660
"""The toast.Comm over which the data is distributed."""

src/toast/future_ops/sim_tod_noise.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
import numpy as np
88

9+
from scipy import interpolate
10+
11+
from .. import rng
12+
913
from ..timing import function_timer
1014

1115
from ..traits import trait_docs, Int, Unicode
@@ -21,17 +25,17 @@
2125

2226
@function_timer
2327
def sim_noise_timestream(
24-
realization,
25-
telescope,
26-
component,
27-
obsindx,
28-
detindx,
29-
rate,
30-
firstsamp,
31-
samples,
32-
oversample,
33-
freq,
34-
psd,
28+
realization=0,
29+
telescope=0,
30+
component=0,
31+
obsindx=0,
32+
detindx=0,
33+
rate=1.0,
34+
firstsamp=0,
35+
samples=0,
36+
oversample=2,
37+
freq=None,
38+
psd=None,
3539
py=False,
3640
):
3741
"""Generate a noise timestream, given a starting RNG state.
@@ -124,7 +128,9 @@ def sim_noise_timestream(
124128
logfreq = np.log10(freq + freqshift)
125129
logpsd = np.log10(psd + psdshift)
126130

127-
interp = si.interp1d(logfreq, logpsd, kind="linear", fill_value="extrapolate")
131+
interp = interpolate.interp1d(
132+
logfreq, logpsd, kind="linear", fill_value="extrapolate"
133+
)
128134

129135
loginterp_psd = interp(loginterp_freq)
130136
interp_psd = np.power(10.0, loginterp_psd) - psdshift
@@ -292,17 +298,18 @@ def _exec(self, data, detectors=None, **kwargs):
292298

293299
# Simulate the noise matching this key
294300
nsedata = sim_noise_timestream(
295-
self.realization,
296-
telescope,
297-
self.component,
298-
obsindx,
299-
nse.index(key),
300-
rate,
301-
ob.local_index_offset + global_offset,
302-
ob.n_local_samples,
303-
self._oversample,
304-
nse.freq(key),
305-
nse.psd(key),
301+
realization=self.realization,
302+
telescope=telescope,
303+
component=self.component,
304+
obsindx=obsindx,
305+
detindx=nse.index(key),
306+
rate=rate,
307+
firstsamp=ob.local_index_offset + global_offset,
308+
samples=ob.n_local_samples,
309+
oversample=self._oversample,
310+
freq=nse.freq(key),
311+
psd=nse.psd(key),
312+
py=False,
306313
)
307314

308315
# Add the noise to all detectors that have nonzero weights

src/toast/observation_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,9 @@ def __getitem__(self, key):
377377
return self._internal[key]
378378

379379
def __delitem__(self, key):
380-
self._internal[key].clear()
381-
del self._internal[key]
380+
if key in self._internal:
381+
self._internal[key].clear()
382+
del self._internal[key]
382383

383384
def __setitem__(self, key, value):
384385
if isinstance(value, DetectorData):
@@ -626,8 +627,9 @@ def __getitem__(self, key):
626627
return self._internal[key]
627628

628629
def __delitem__(self, key):
629-
self._internal[key].close()
630-
del self._internal[key]
630+
if key in self._internal:
631+
self._internal[key].close()
632+
del self._internal[key]
631633

632634
def __setitem__(self, key, value):
633635
if isinstance(value, MPIShared):

src/toast/tests/_helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@ def create_telescope(group_size, sample_rate=10.0 * u.Hz):
8787
while 2 * npix < group_size:
8888
npix += 6 * ring
8989
ring += 1
90-
fp = fake_hexagon_focalplane(n_pix=npix)
90+
fp = fake_hexagon_focalplane(
91+
n_pix=npix,
92+
sample_rate=sample_rate,
93+
f_min=1.0e-5 * u.Hz,
94+
f_knee=(sample_rate / 2000.0),
95+
)
9196
return Telescope("test", focalplane=fp)
9297

9398

@@ -147,7 +152,7 @@ def create_satellite_data(
147152

148153
sim_sat = ops.SimSatellite(
149154
name="sim_sat",
150-
n_observation=(toastcomm.ngroups * obs_per_group),
155+
num_observations=(toastcomm.ngroups * obs_per_group),
151156
telescope=tele,
152157
hwp_rpm=10.0,
153158
observation_time=obs_time,

src/toast/tests/mpi.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,46 @@
22
# All rights reserved. Use of this source code is governed by
33
# a BSD-style license that can be found in the LICENSE file.
44

5+
from ..mpi import MPI, use_mpi
6+
57
import sys
68
import time
79

810
import warnings
911

1012
from unittest.signals import registerResult
13+
1114
from unittest import TestCase
1215
from unittest import TestResult
1316

1417

1518
class MPITestCase(TestCase):
16-
"""A simple wrapper around the standard TestCase which provides
17-
one extra method to set the communicator.
18-
"""
19+
"""A simple wrapper around the standard TestCase which stores the communicator."""
1920

2021
def __init__(self, *args, **kwargs):
21-
super(MPITestCase, self).__init__(*args, **kwargs)
22-
23-
def setComm(self, comm):
24-
self.comm = comm
22+
super().__init__(*args, **kwargs)
23+
self.comm = None
24+
if use_mpi:
25+
self.comm = MPI.COMM_WORLD
2526

2627

2728
class MPITestResult(TestResult):
2829
"""A test result class that can print formatted text results to a stream.
2930
3031
The actions needed are coordinated across all processes.
3132
32-
Used by MPITestRunner.
3333
"""
3434

3535
separator1 = "=" * 70
3636
separator2 = "-" * 70
3737

38-
def __init__(self, comm, stream=None, descriptions=None, verbosity=None, **kwargs):
39-
super(MPITestResult, self).__init__(
38+
def __init__(self, stream=None, descriptions=None, verbosity=None, **kwargs):
39+
super().__init__(
4040
stream=stream, descriptions=descriptions, verbosity=verbosity, **kwargs
4141
)
42-
self.comm = comm
42+
self.comm = None
43+
if use_mpi:
44+
self.comm = MPI.COMM_WORLD
4345
self.stream = stream
4446
self.descriptions = descriptions
4547
self.buffer = False
@@ -53,8 +55,7 @@ def getDescription(self, test):
5355
return str(test)
5456

5557
def startTest(self, test):
56-
if isinstance(test, MPITestCase):
57-
test.setComm(self.comm)
58+
super().startTest(test)
5859
self.stream.flush()
5960
if self.comm is not None:
6061
self.comm.barrier()
@@ -65,11 +66,10 @@ def startTest(self, test):
6566
self.stream.flush()
6667
if self.comm is not None:
6768
self.comm.barrier()
68-
super(MPITestResult, self).startTest(test)
6969
return
7070

7171
def addSuccess(self, test):
72-
super(MPITestResult, self).addSuccess(test)
72+
super().addSuccess(test)
7373
if self.comm is None:
7474
self.stream.write("ok ")
7575
else:
@@ -78,7 +78,7 @@ def addSuccess(self, test):
7878
return
7979

8080
def addError(self, test, err):
81-
super(MPITestResult, self).addError(test, err)
81+
super().addError(test, err)
8282
if self.comm is None:
8383
self.stream.write("error ")
8484
else:
@@ -87,7 +87,7 @@ def addError(self, test, err):
8787
return
8888

8989
def addFailure(self, test, err):
90-
super(MPITestResult, self).addFailure(test, err)
90+
super().addFailure(test, err)
9191
if self.comm is None:
9292
self.stream.write("fail ")
9393
else:
@@ -96,7 +96,7 @@ def addFailure(self, test, err):
9696
return
9797

9898
def addSkip(self, test, reason):
99-
super(MPITestResult, self).addSkip(test, reason)
99+
super().addSkip(test, reason)
100100
if self.comm is None:
101101
self.stream.write("skipped({}) ".format(reason))
102102
else:
@@ -105,7 +105,7 @@ def addSkip(self, test, reason):
105105
return
106106

107107
def addExpectedFailure(self, test, err):
108-
super(MPITestResult, self).addExpectedFailure(test, err)
108+
super().addExpectedFailure(test, err)
109109
if self.comm is None:
110110
self.stream.write("expected-fail ")
111111
else:
@@ -114,11 +114,11 @@ def addExpectedFailure(self, test, err):
114114
return
115115

116116
def addUnexpectedSuccess(self, test):
117-
super(MPITestResult, self).addUnexpectedSuccess(test)
117+
super().addUnexpectedSuccess(test)
118118
if self.comm is None:
119-
self.stream.writeln("unexpected-success ")
119+
self.stream.write("unexpected-success ")
120120
else:
121-
self.stream.writeln("[{}]unexpected-success ".format(self.comm.rank))
121+
self.stream.write("[{}]unexpected-success ".format(self.comm.rank))
122122
return
123123

124124
def printErrorList(self, flavour, errors):
@@ -142,15 +142,13 @@ def printErrorList(self, flavour, errors):
142142
def printErrors(self):
143143
if self.comm is None:
144144
self.stream.writeln()
145-
self.stream.flush()
146145
self.printErrorList("ERROR", self.errors)
147146
self.printErrorList("FAIL", self.failures)
148147
self.stream.flush()
149148
else:
150149
self.comm.barrier()
151150
if self.comm.rank == 0:
152151
self.stream.writeln()
153-
self.stream.flush()
154152
for p in range(self.comm.size):
155153
if p == self.comm.rank:
156154
self.printErrorList("ERROR", self.errors)
@@ -203,15 +201,15 @@ class MPITestRunner(object):
203201

204202
resultclass = MPITestResult
205203

206-
def __init__(
207-
self, comm, stream=None, descriptions=True, verbosity=2, warnings=None
208-
):
204+
def __init__(self, stream=None, descriptions=True, verbosity=2, warnings=None):
209205
"""Construct a MPITestRunner.
210206
211207
Subclasses should accept **kwargs to ensure compatibility as the
212208
interface changes.
213209
"""
214-
self.comm = comm
210+
self.comm = None
211+
if use_mpi:
212+
self.comm = MPI.COMM_WORLD
215213
if stream is None:
216214
stream = sys.stderr
217215
self.stream = _WritelnDecorator(stream)
@@ -221,9 +219,7 @@ def __init__(
221219

222220
def run(self, test):
223221
"Run the given test case or test suite."
224-
result = MPITestResult(
225-
self.comm, self.stream, self.descriptions, self.verbosity
226-
)
222+
result = MPITestResult(self.stream, self.descriptions, self.verbosity)
227223
registerResult(result)
228224
with warnings.catch_warnings():
229225
if self.warnings:

0 commit comments

Comments
 (0)