Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the first composite type: Sum #157

Merged
merged 12 commits into from
Apr 16, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ jobs:
# Extra packages needed for testing
# Note: Pin pillow <10 until this bug is fixed:
# https://github.com/python-pillow/Pillow/issues/7259
pip install -U nose coverage pytest nbval ipykernel "pillow<10"
pip install -U nose coverage "pytest<8" nbval ipykernel "pillow<10"

- name: Install Pixmappy (not on pip)
run: |
Expand Down
3 changes: 3 additions & 0 deletions piff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@
# Optics
from .optical_model import Optical, optical_templates

# Composite PSFs
from .sumpsf import SumPSF

# Leave these in their own namespaces
from . import util
from . import des
Expand Down
2 changes: 1 addition & 1 deletion piff/gsobject_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def fit(self, star, fastfit=None, logger=None, convert_func=None):
if self._centered:
center = (du, dv)
else:
center = (0.0, 0.0)
center = star.fit.center
params = [du, dv] + params
params_var = np.concatenate([var[1:3], params_var])
if self._fit_flux:
Expand Down
2 changes: 1 addition & 1 deletion piff/pixelgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def chisq(self, star, logger=None, convert_func=None):
v = v[mask] / self.scale
ui,uf = np.divmod(u,1)
vi,vf = np.divmod(v,1)
xr = self.interp.xrange
xr = int(np.ceil(self.interp.xrange))
# Note arguments are basis pixel position minus image pixel position.
# Hence the minus sign in front of uf.
argu = -uf[:,np.newaxis] + np.arange(-xr+1,xr+1)[np.newaxis,:]
Expand Down
42 changes: 28 additions & 14 deletions piff/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def set_num(self, num):
# But this is the minimum action that all subclasses need to do.
self._num = num

@property
def num_components(self):
# Subclasses for which this is not true can overwrite this
return 1

@classmethod
def __init_subclass__(cls):
# Classes that don't want to register a type name can either not define _type_name
Expand Down Expand Up @@ -694,7 +699,12 @@ def _drawStar(self, star, center=None):
raise NotImplementedError("Derived classes must define the _drawStar function")

def _getProfile(self, star):
raise NotImplementedError("Derived classes must define the _getProfile function")
prof, method = self._getRawProfile(star)
prof = prof.shift(star.fit.center) * star.fit.flux
return prof, method

def _getRawProfile(self, star):
raise NotImplementedError("Derived classes must define the _getRawProfile function")

def write(self, file_name, logger=None):
"""Write a PSF object to a file.
Expand Down Expand Up @@ -723,10 +733,12 @@ def _write(self, fits, extname, logger=None):
psf_type = self._type_name
write_kwargs(fits, extname, dict(self.kwargs, type=psf_type, piff_version=piff_version))
logger.info("Wrote the basic PSF information to extname %s", extname)
Star.write(self.stars, fits, extname=extname + '_stars')
logger.info("Wrote the PSF stars to extname %s", extname + '_stars')
self.writeWCS(fits, extname=extname + '_wcs', logger=logger)
logger.info("Wrote the PSF WCS to extname %s", extname + '_wcs')
if hasattr(self, 'stars'):
Star.write(self.stars, fits, extname=extname + '_stars')
logger.info("Wrote the PSF stars to extname %s", extname + '_stars')
if hasattr(self, 'wcs'):
self.writeWCS(fits, extname=extname + '_wcs', logger=logger)
logger.info("Wrote the PSF WCS to extname %s", extname + '_wcs')
self._finish_write(fits, extname=extname, logger=logger)

@classmethod
Expand Down Expand Up @@ -773,17 +785,19 @@ def _read(cls, fits, extname, logger):
raise ValueError("psf type %s is not a valid Piff PSF"%psf_type)
psf_cls = PSF.valid_psf_types[psf_type]

# Read the stars, wcs, pointing values
stars = Star.read(fits, extname + '_stars')
logger.debug("stars = %s",stars)
wcs, pointing = cls.readWCS(fits, extname + '_wcs', logger=logger)
logger.debug("wcs = %s, pointing = %s",wcs,pointing)

# Make the PSF instance
psf = psf_cls(**kwargs)
psf.stars = stars
psf.wcs = wcs
psf.pointing = pointing

# Read the stars, wcs, pointing values
if extname + '_stars' in fits:
stars = Star.read(fits, extname + '_stars')
logger.debug("stars = %s",stars)
psf.stars = stars
if extname + '_wcs' in fits:
wcs, pointing = cls.readWCS(fits, extname + '_wcs', logger=logger)
logger.debug("wcs = %s, pointing = %s",wcs,pointing)
psf.wcs = wcs
psf.pointing = pointing

# Just in case the class needs to do something else at the end.
psf._finish_read(fits, extname, logger)
Expand Down
5 changes: 0 additions & 5 deletions piff/simplepsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,6 @@ def interpolateStar(self, star):
def _drawStar(self, star, center=None):
return self.model.draw(star, center=center)

def _getProfile(self, star):
prof, method = self._getRawProfile(star)
prof = prof.shift(star.fit.center) * star.fit.flux
return prof, method

def _getRawProfile(self, star):
return self.model.getProfile(star.fit.get_params(self._num)), self.model._method

Expand Down
2 changes: 1 addition & 1 deletion piff/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def write(self, stars, fits, extname):

# params might need to be flattened
if stars[0].fit.params_lens is not None:
header = {'PARAMS_LENS' : stars[0].fit.params_lens}
header = {'PARAMS_LENS' : str(stars[0].fit.params_lens)}
params = [ StarFit.flatten_params(p) for p in params ]

dtypes.append( ('params', float, len(params[0])) )
Expand Down
36 changes: 29 additions & 7 deletions piff/star_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,23 @@ class StarStats(Stats):
(without replacement). [default: 10]
:param adjust_stars: Boolean. If true, when computing, will also fit for best
starfit center and flux to match observed star. [default: False]
:param include_reserve: Whether to inlude reserve stars. [default: True]
:param only_reserve: Whether to skip plotting non-reserve stars. [default: False]
:param include_flaggede: Whether to include plotting flagged stars. [default: False]
:param file_name: Name of the file to output to. [default: None]
:param logger: A logger object for logging debug info. [default: None]
"""
_type_name = 'StarImages'

def __init__(self, nplot=10, adjust_stars=False, file_name=None, logger=None):
def __init__(self, nplot=10, adjust_stars=False,
include_reserve=True, only_reserve=False, include_flagged=False,
file_name=None, logger=None):
self.nplot = nplot
self.file_name = file_name
self.adjust_stars = adjust_stars
self.include_reserve = include_reserve
self.only_reserve = only_reserve
self.include_flagged = include_flagged

def compute(self, psf, stars, logger=None):
"""
Expand All @@ -58,12 +66,21 @@ def compute(self, psf, stars, logger=None):
:param logger: A logger object for logging debug info. [default: None]
"""
logger = galsim.config.LoggerWrapper(logger)
# get the shapes
# Determine which stars to plot
possible_indices = []
if self.include_reserve:
possible_indices += [i for i,s in enumerate(stars)
if s.is_reserve and (self.include_flagged or not s.is_flagged)]
if not self.only_reserve:
possible_indices += [i for i,s in enumerate(stars)
if not s.is_reserve and (self.include_flagged or not s.is_flagged)]
possible_indices = sorted(possible_indices)

if self.nplot == 0 or self.nplot >= len(stars):
# select all stars
self.indices = np.arange(len(stars))
# select all viable stars
self.indices = possible_indices
else:
self.indices = np.random.choice(len(stars), self.nplot, replace=False)
self.indices = np.random.choice(possible_indices, self.nplot, replace=False)

logger.info("Making {0} Model Stars".format(len(self.indices)))
self.stars = []
Expand Down Expand Up @@ -110,8 +127,13 @@ def plot(self, logger=None, **kwargs):
ii = i // 2
jj = (i % 2) * 3

axs[ii][jj+0].set_title('Star {0}'.format(index))
axs[ii][jj+1].set_title('PSF at (u,v) = \n ({0:+.02e}, {1:+.02e})'.format(u, v))
title = f'Star {index}'
if star.is_reserve:
title = 'Reserve ' + title
if star.is_flagged:
title = 'Flagged ' + title
axs[ii][jj+0].set_title(title)
axs[ii][jj+1].set_title(f'PSF at (u,v) = \n ({u:+.02e}, {v:+.02e})')
axs[ii][jj+2].set_title('Star - PSF')

star_image = star.image
Expand Down
Loading
Loading