Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the first composite type: Sum #157

Merged
merged 12 commits into from
Apr 16, 2024
3 changes: 3 additions & 0 deletions piff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@
# Optics
from .optical_model import Optical, optical_templates

# Composite PSFs
from .sumpsf import SumPSF

# Leave these in their own namespaces
from . import util
from . import des
Expand Down
2 changes: 1 addition & 1 deletion piff/gsobject_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def fit(self, star, fastfit=None, logger=None, convert_func=None):
if self._centered:
center = (du, dv)
else:
center = (0.0, 0.0)
center = star.fit.center
params = [du, dv] + params
params_var = np.concatenate([var[1:3], params_var])
if self._fit_flux:
Expand Down
42 changes: 28 additions & 14 deletions piff/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def set_num(self, num):
# But this is the minimum action that all subclasses need to do.
self._num = num

@property
def num_components(self):
# Subclasses for which this is not true can overwrite this
return 1

@classmethod
def __init_subclass__(cls):
# Classes that don't want to register a type name can either not define _type_name
Expand Down Expand Up @@ -694,7 +699,12 @@ def _drawStar(self, star, center=None):
raise NotImplementedError("Derived classes must define the _drawStar function")

def _getProfile(self, star):
raise NotImplementedError("Derived classes must define the _getProfile function")
prof, method = self._getRawProfile(star)
prof = prof.shift(star.fit.center) * star.fit.flux
return prof, method

def _getRawProfile(self, star):
raise NotImplementedError("Derived classes must define the _getRawProfile function")

def write(self, file_name, logger=None):
"""Write a PSF object to a file.
Expand Down Expand Up @@ -723,10 +733,12 @@ def _write(self, fits, extname, logger=None):
psf_type = self._type_name
write_kwargs(fits, extname, dict(self.kwargs, type=psf_type, piff_version=piff_version))
logger.info("Wrote the basic PSF information to extname %s", extname)
Star.write(self.stars, fits, extname=extname + '_stars')
logger.info("Wrote the PSF stars to extname %s", extname + '_stars')
self.writeWCS(fits, extname=extname + '_wcs', logger=logger)
logger.info("Wrote the PSF WCS to extname %s", extname + '_wcs')
if hasattr(self, 'stars'):
Star.write(self.stars, fits, extname=extname + '_stars')
logger.info("Wrote the PSF stars to extname %s", extname + '_stars')
if hasattr(self, 'wcs'):
self.writeWCS(fits, extname=extname + '_wcs', logger=logger)
logger.info("Wrote the PSF WCS to extname %s", extname + '_wcs')
self._finish_write(fits, extname=extname, logger=logger)

@classmethod
Expand Down Expand Up @@ -773,17 +785,19 @@ def _read(cls, fits, extname, logger):
raise ValueError("psf type %s is not a valid Piff PSF"%psf_type)
psf_cls = PSF.valid_psf_types[psf_type]

# Read the stars, wcs, pointing values
stars = Star.read(fits, extname + '_stars')
logger.debug("stars = %s",stars)
wcs, pointing = cls.readWCS(fits, extname + '_wcs', logger=logger)
logger.debug("wcs = %s, pointing = %s",wcs,pointing)

# Make the PSF instance
psf = psf_cls(**kwargs)
psf.stars = stars
psf.wcs = wcs
psf.pointing = pointing

# Read the stars, wcs, pointing values
if extname + '_stars' in fits:
stars = Star.read(fits, extname + '_stars')
logger.debug("stars = %s",stars)
psf.stars = stars
if extname + '_wcs' in fits:
wcs, pointing = cls.readWCS(fits, extname + '_wcs', logger=logger)
logger.debug("wcs = %s, pointing = %s",wcs,pointing)
psf.wcs = wcs
psf.pointing = pointing

# Just in case the class needs to do something else at the end.
psf._finish_read(fits, extname, logger)
Expand Down
5 changes: 0 additions & 5 deletions piff/simplepsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,6 @@ def interpolateStar(self, star):
def _drawStar(self, star, center=None):
return self.model.draw(star, center=center)

def _getProfile(self, star):
prof, method = self._getRawProfile(star)
prof = prof.shift(star.fit.center) * star.fit.flux
return prof, method

def _getRawProfile(self, star):
return self.model.getProfile(star.fit.get_params(self._num)), self.model._method

Expand Down
2 changes: 1 addition & 1 deletion piff/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def write(self, stars, fits, extname):

# params might need to be flattened
if stars[0].fit.params_lens is not None:
header = {'PARAMS_LENS' : stars[0].fit.params_lens}
header = {'PARAMS_LENS' : str(stars[0].fit.params_lens)}
params = [ StarFit.flatten_params(p) for p in params ]

dtypes.append( ('params', float, len(params[0])) )
Expand Down
Loading
Loading