Skip to content

Commit

Permalink
Refactor write methods to delegate to new Writer interface.
Browse files Browse the repository at this point in the history
Writer classes still need docs, and I haven't run any tests, but this
should exactly preserve the previous file format via FitsWriter.

Still need to do the same for reading and see that it all hangs
together, then see what to do with Output.

Writer has type annotations for now; will drop them later for
consistency.
  • Loading branch information
TallJimbo committed Jul 17, 2024
1 parent a489eef commit 007d9a8
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 178 deletions.
7 changes: 3 additions & 4 deletions piff/basis_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,19 +434,18 @@ def constant(self, value=1.):
out[0] = value # The constant term is always first.
return out

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Write the solution to a FITS binary table.
:param fits: An open fitsio.FITS object.
:param extname: The base name of the extension.
:param writer: A writer object that encapsulates the serialization format.
"""
if self.q is None:
raise RuntimeError("Solution not set yet. Cannot write this BasisPolynomial.")

dtypes = [ ('q', float, self.q.shape) ]
data = np.zeros(1, dtype=dtypes)
data['q'] = self.q
fits.write_table(data, extname=extname + '_solution')
writer.write_table('solution', data)

def _finish_read(self, fits, extname):
"""Read the solution from a FITS binary table.
Expand Down
17 changes: 8 additions & 9 deletions piff/convolvepsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import galsim

from .psf import PSF
from .util import write_kwargs, read_kwargs
from .util import read_kwargs
from .star import Star, StarFit
from .outliers import Outliers

Expand Down Expand Up @@ -271,11 +271,10 @@ def _getRawProfile(self, star, skip=None):
else:
return galsim.Convolve(profiles), method

def _finish_write(self, fits, extname, logger):
def _finish_write(self, writer, logger):
"""Finish the writing process with any class-specific steps.
:param fits: An open fitsio.FITS object
:param extname: The base name of the extension to write to.
:param writer: A writer object that encapsulates the serialization format.
:param logger: A logger object for logging debug info.
"""
logger = galsim.config.LoggerWrapper(logger)
Expand All @@ -285,13 +284,13 @@ def _finish_write(self, fits, extname, logger):
'dof' : self.dof,
'nremoved' : self.nremoved,
}
write_kwargs(fits, extname + '_chisq', chisq_dict)
logger.debug("Wrote the chisq info to extension %s",extname + '_chisq')
writer.write_struct('chisq', chisq_dict)
logger.debug("Wrote the chisq info to extension %s", writer.get_full_name('chisq'))
for k, comp in enumerate(self.components):
comp._write(fits, extname + '_' + str(k), logger=logger)
comp._write(writer, str(k), logger=logger)
if self.outliers:
self.outliers.write(fits, extname + '_outliers')
logger.debug("Wrote the PSF outliers to extension %s",extname + '_outliers')
self.outliers._write(writer, 'outliers')
logger.debug("Wrote the PSF outliers to extension %s", writer.get_full_name('outliers'))

def _finish_read(self, fits, extname, logger):
"""Finish the reading process with any class-specific steps.
Expand Down
7 changes: 3 additions & 4 deletions piff/gp_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,10 @@ def interpolateList(self, stars, logger=None):
fitted_stars.append(Star(star.data, fit))
return fitted_stars

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Finish the writing process with any class-specific steps.
:param fits: An open fitsio.FITS object
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
"""
# Note, we're only storing the training data and hyperparameters here, which means the
# Cholesky decomposition will have to be re-computed when this object is read back from
Expand All @@ -320,7 +319,7 @@ def _finish_write(self, fits, extname):
data['ROWS'] = self.rows
data['OPTIMIZER'] = self.optimizer

fits.write_table(data, extname=extname+'_kernel')
writer.write_table('kernel', data)

def _finish_read(self, fits, extname):
"""Finish the reading process with any class-specific steps.
Expand Down
28 changes: 20 additions & 8 deletions piff/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,32 +177,44 @@ def interpolateList(self, stars, logger=None):
def write(self, fits, extname):
"""Write an Interp to a FITS file.
This method exists for backwards compatibility; subclasses should
reimplement _write or _finish_write instead.
:param fits: An open fitsio.FITS object
:param extname: The name of the extension to write the interpolator information.
"""
from .writers import FitsWriter
self._write(FitsWriter(fits, None, {}), extname)

def _write(self, writer, name):
"""Write an Interp via a Writer object.
Note: this only writes the initialization kwargs to the fits extension, not the parameters.
The base class implemenation works if the class has a self.kwargs attribute and these
The base class implementation works if the class has a self.kwargs attribute and these
are all simple values (str, float, or int).
However, the derived class will need to implement _finish_write to write the solution
parameters to a binary table.
:param fits: An open fitsio.FITS object
:param extname: The name of the extension to write the interpolator information.
:param writer: A writer object that encapsulates the serialization format.
:param name: A name to associate with this interpolator in the serialized output.
"""
# First write the basic kwargs that works for all Interp classes
interp_type = self.__class__._type_name
write_kwargs(fits, extname, dict(self.kwargs, type=interp_type))
writer.write_struct(name, dict(self.kwargs, type=interp_type))

# Now do the class-specific steps. Typically, this will write out the solution parameters.
self._finish_write(fits, extname)
with writer.nested(name) as w:
self._finish_write(w)

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Finish the writing process with any class-specific steps.
The base class implementation doesn't do anything, but this will probably always be
overridden by the derived class.
:param fits: An open fitsio.FITS object
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
"""
raise NotImplementedError("Derived classes must define the _finish_write method.")

Expand Down
8 changes: 3 additions & 5 deletions piff/knn_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,12 @@ def interpolateList(self, stars, logger=None):
stars_fitted.append(Star(star.data, fit))
return stars_fitted

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Write the solution to a FITS binary table.
Save the knn params and the locations and targets arrays
:param fits: An open fitsio.FITS object.
:param extname: The base name of the extension with the interp information.
:param writer: A writer object that encapsulates the serialization format.
"""

dtypes = [('LOCATIONS', self.locations.dtype, self.locations.shape),
Expand All @@ -173,8 +172,7 @@ def _finish_write(self, fits, extname):
data['LOCATIONS'] = self.locations
data['TARGETS'] = self.targets

# write to fits
fits.write_table(data, extname=extname + '_solution')
writer.write_table('solution', data)

def _finish_read(self, fits, extname):
"""Read the solution from a FITS binary table.
Expand Down
10 changes: 3 additions & 7 deletions piff/mean_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,12 @@ def interpolate(self, star, logger=None):
fit = star.fit.newParams(self.mean, num=self._num)
return Star(star.data, fit)

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Write the solution to a FITS binary table.
:param fits: An open fitsio.FITS object.
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
"""
cols = [ self.mean ]
dtypes = [ ('mean', float) ]
data = np.array(list(zip(*cols)), dtype=dtypes)
fits.write_table(data, extname=extname + '_solution')
writer.write_struct('solution', {'mean': self.mean})

def _finish_read(self, fits, extname):
"""Read the solution from a FITS binary table.
Expand Down
32 changes: 20 additions & 12 deletions piff/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
.. module:: model
"""

import numpy as np
import galsim

from .util import write_kwargs, read_kwargs
from .util import read_kwargs
from .star import Star


Expand Down Expand Up @@ -157,30 +154,41 @@ def draw(self, star, copy_image=True):
def write(self, fits, extname):
"""Write a Model to a FITS file.
This method exists for backwards compatibility; subclasses should
reimplement _write or _finish_write instead.
:param fits: An open fitsio.FITS object
:param extname: The name of the extension to write the model information.
"""
from .writers import FitsWriter
self._write(FitsWriter(fits, None, {}), extname)

def _write(self, writer, name):
"""Write a Model via a Writer object.
Note: this only writes the initialization kwargs to the fits extension, not the parameters.
The base class implemenation works if the class has a self.kwargs attribute and these
are all simple values (str, float, or int)
:param fits: An open fitsio.FITS object
:param extname: The name of the extension to write the model information.
:param writer: A writer object that encapsulates the serialization format.
:param name: A name to associate with this model in the serialized output.
"""
# First write the basic kwargs that works for all Model classes
model_type = self._type_name
write_kwargs(fits, extname, dict(self.kwargs, type=model_type))

writer.write_struct(name, dict(self.kwargs, type=model_type))
# Now do any class-specific steps.
self._finish_write(fits, extname)
with writer.nested(name) as w:
self._finish_write(w)

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Finish the writing process with any class-specific steps.
The base class implementation doesn't do anything, which is often appropriate, but
this hook exists in case any Model classes need to write extra information to the
fits file.
:param fits: An open fitsio.FITS object
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
"""
pass

Expand Down
22 changes: 16 additions & 6 deletions piff/outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import galsim
from scipy.stats import chi2

from .util import write_kwargs, read_kwargs
from .util import read_kwargs

class Outliers(object):
"""The base class for handling outliers.
Expand Down Expand Up @@ -97,22 +97,32 @@ def write(self, fits, extname):
:param fits: An open fitsio.FITS object
:param extname: The name of the extension to write the outliers information.
"""
from .writers import FitsWriter
self._write(FitsWriter(fits, None, {}), extname)

def _write(self, writer, name):
"""Write an Outers via a Writer object.
:param writer: A writer object that encapsulates the serialization format.
:param name: A name to associate with the Ootliers in the serialized output.
"""
# First write the basic kwargs that works for all Outliers classes
outliers_type = self._type_name
write_kwargs(fits, extname, dict(self.kwargs, type=outliers_type))
writer.write_struct(name, dict(self.kwargs, type=outliers_type))

# Now do any class-specific steps.
self._finish_write(fits, extname)
with writer.nested(name) as w:
self._finish_write(w)

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Finish the writing process with any class-specific steps.
The base class implementation doesn't do anything, which is often appropriate, but
this hook exists in case any Outliers classes need to write extra information to the
fits file.
:param fits: An open fitsio.FITS object
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
:param name: A name to associate with the Ootliers in the serialized output.
"""
pass

Expand Down
7 changes: 3 additions & 4 deletions piff/polynomial_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,10 @@ def model(uv,*coeffs):
# self._unpack_coefficients
self.coeffs = coeffs

def _finish_write(self, fits, extname):
def _finish_write(self, writer):
"""Write the solution to a FITS binary table.
:param fits: An open fitsio.FITS object.
:param extname: The base name of the extension
:param writer: A writer object that encapsulates the serialization format.
"""
if self.coeffs is None:
raise RuntimeError("Coeffs not set yet. Cannot write this Polynomial.")
Expand Down Expand Up @@ -368,7 +367,7 @@ def _finish_write(self, fits, extname):

# Finally, write all of this to a FITS table.
data = np.array(list(zip(*cols)), dtype=dtypes)
fits.write_table(data, extname=extname + '_solution', header=header)
writer.write_table('solution', data, metadata=header)


def _finish_read(self, fits, extname):
Expand Down
Loading

0 comments on commit 007d9a8

Please sign in to comment.