Skip to content

Commit

Permalink
Add option to include average star/model in StarImages plots
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Jul 3, 2024
1 parent c5a1240 commit e9d05cf
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 33 deletions.
88 changes: 63 additions & 25 deletions piff/star_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,23 @@ class StarStats(Stats):
starfit center and flux to match observed star. [default: False]
:param include_reserve: Whether to inlude reserve stars. [default: True]
:param only_reserve: Whether to skip plotting non-reserve stars. [default: False]
:param include_flaggede: Whether to include plotting flagged stars. [default: False]
:param include_flagged: Whether to include plotting flagged stars. [default: False]
:param include_ave: Whether to inlude the average image. [default: True]
:param file_name: Name of the file to output to. [default: None]
:param logger: A logger object for logging debug info. [default: None]
"""
_type_name = 'StarImages'

def __init__(self, nplot=10, adjust_stars=False,
include_reserve=True, only_reserve=False, include_flagged=False,
file_name=None, logger=None):
include_ave=True, file_name=None, logger=None):
self.nplot = nplot
self.file_name = file_name
self.adjust_stars = adjust_stars
self.include_reserve = include_reserve
self.only_reserve = only_reserve
self.include_flagged = include_flagged
self.include_ave = include_ave

def compute(self, psf, stars, logger=None):
"""
Expand All @@ -82,16 +84,48 @@ def compute(self, psf, stars, logger=None):
else:
self.indices = np.random.choice(possible_indices, self.nplot, replace=False)

logger.info("Making {0} Model Stars".format(len(self.indices)))
self.stars = []
for index in self.indices:
star = stars[index]
if self.adjust_stars:
# Do 2 passes, since we sometimes start pretty far from the right values.
star = psf.reflux(star, logger=logger)
star = psf.reflux(star, logger=logger)
self.stars.append(star)
self.models = psf.drawStarList(self.stars)
# If we need to compute the average image, then we need to reflux and drawStar for all
# possible_indices. Otherwise, only do those steps for the stars we will plot.
if self.include_ave:
calculate_indices = possible_indices
else:
calculate_indices = self.indices

logger.info("Making {0} model stars".format(len(calculate_indices)))
calculated_stars = []
calculated_models = []
for i, star in enumerate(stars):
if i in calculate_indices:
if self.adjust_stars:
# Do 2 passes, since we sometimes start pretty far from the right values.
star = psf.reflux(star, logger=logger)
star = psf.reflux(star, logger=logger)
calculated_stars.append(star)
calculated_models.append(psf.drawStar(star))
else:
calculated_stars.append(None)
calculated_models.append(None)

# if including the average image, put that first.
logger.info("Making average star and model")
if self.include_ave:
ave_star_image = np.mean([s.image.array for s in calculated_stars if s is not None],
axis=0)
ave_model_image = np.mean([s.image.array for s in calculated_models if s is not None],
axis=0)
ave_star_image = galsim.Image(ave_star_image)
ave_model_image = galsim.Image(ave_model_image)
ave_star = Star(stars[0].data.withNew(image=ave_star_image), None)
ave_model = Star(stars[0].data.withNew(image=ave_model_image), None)
self.stars = [ave_star]
self.models = [ave_model]
self.stars.extend([calculated_stars[i] for i in self.indices])
self.models.extend([calculated_models[i] for i in self.indices])
self.indices = [-1] + self.indices
else:
self.stars = [calculated_stars[i] for i in self.indices]
self.models = [calculated_models[i] for i in self.indices]


def plot(self, logger=None, **kwargs):
r"""Make the plots.
Expand All @@ -115,25 +149,29 @@ def plot(self, logger=None, **kwargs):

logger.info("Creating %d Star plots", self.nplot)

for i in range(len(self.indices)):
for i in range(nplot):
star = self.stars[i]
model = self.models[i]

# get index, u, v coordinates to put in title
u = star.data.properties['u']
v = star.data.properties['v']
index = self.indices[i]

ii = i // 2
jj = (i % 2) * 3

title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
if self.include_ave and i == 0:
axs[ii][jj+0].set_title('Average Star')
axs[ii][jj+1].set_title('Average PSF')
else:
# get index, u, v coordinates to put in title
index = self.indices[i]
u = star.data.properties['u']
v = star.data.properties['v']

title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
axs[ii][jj+2].set_title('Star - PSF')

star_image = star.image
Expand Down
30 changes: 22 additions & 8 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def test_starstats_config():
'file_name': star_file,
'nplot': 5,
'adjust_stars': True,
'include_ave': False,
}
]
}
Expand All @@ -550,7 +551,7 @@ def test_starstats_config():

# check default nplot
psf = piff.read(psf_file)
starStats = piff.StarStats()
starStats = piff.StarStats(include_ave=False)
orig_stars, wcs, pointing = piff.Input.process(config['input'], logger=logger)
orig_stars = piff.Select.process(config['select'], orig_stars, logger=logger)
with np.testing.assert_raises(RuntimeError):
Expand All @@ -563,12 +564,20 @@ def test_starstats_config():
orig_stars[starStats.indices[2]].image.array)

# check nplot = 6
starStats = piff.StarStats(nplot=6)
starStats = piff.StarStats(nplot=6, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 6

starStats = piff.StarStats(nplot=6, include_ave=True)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 7

starStats = piff.StarStats(nplot=6) # include_ave=True is the default
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 7

# check nplot >> len(stars)
starStats = piff.StarStats(nplot=1000000)
starStats = piff.StarStats(nplot=1000000, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == len(orig_stars)
# if use all stars, no randomness
Expand All @@ -577,7 +586,7 @@ def test_starstats_config():
starStats.plot() # Make sure this runs without error and in finite time.

# check nplot = 0
starStats = piff.StarStats(nplot=0)
starStats = piff.StarStats(nplot=0, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == len(orig_stars)
# if use all stars, no randomness
Expand All @@ -588,20 +597,25 @@ def test_starstats_config():
# With include_reserve=False, only 8 stars
print('All stars: n=',len(starStats.stars)) # 10 stars total
assert len(starStats.stars) == 10
starStats = piff.StarStats(nplot=0, include_reserve=False)
starStats = piff.StarStats(nplot=0, include_reserve=False, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 8
starStats.plot() # Make sure this runs without error.

# With only_reserve=True, only 2 stars
starStats = piff.StarStats(nplot=0, only_reserve=True)
starStats = piff.StarStats(nplot=0, only_reserve=True, include_ave=False)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 2
starStats.plot() # Make sure this runs without error.

starStats = piff.StarStats(nplot=0, only_reserve=True)
starStats.compute(psf, orig_stars)
assert len(starStats.stars) == 3
starStats.plot() # Make sure this runs without error.

# rerun with adjust stars and see if it did the right thing
# first with adjust_stars == False
starStats = piff.StarStats(nplot=0, adjust_stars=False)
starStats = piff.StarStats(nplot=0, adjust_stars=False, include_ave=False)
starStats.compute(psf, orig_stars, logger=logger)
fluxs_noadjust = np.array([s.fit.flux for s in starStats.stars])
ds_noadjust = np.array([s.fit.center for s in starStats.stars])
Expand All @@ -611,7 +625,7 @@ def test_starstats_config():
np.testing.assert_array_equal(ds_noadjust, 0)

# now with adjust_stars == True
starStats = piff.StarStats(nplot=0, adjust_stars=True)
starStats = piff.StarStats(nplot=0, adjust_stars=True, include_ave=False)
starStats.compute(psf, orig_stars, logger=logger)
fluxs_adjust = np.array([s.fit.flux for s in starStats.stars])
ds_adjust = np.array([s.fit.center for s in starStats.stars])
Expand Down

0 comments on commit e9d05cf

Please sign in to comment.