From 007d9a85fd5a383199779d39300fdcd87f8db0bf Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Wed, 17 Jul 2024 17:13:10 -0400 Subject: [PATCH] Refactor write methods to delegate to new Writer interface. Writer classes still need docs, and I haven't run any tests, but this should exactly preserve the previous file format via FitsWriter. Still need to do the same for reading and see that it all hangs together, then see what to do with Output. Writer has type annotations for now; will drop them later for consistency. --- piff/basis_interp.py | 7 +- piff/convolvepsf.py | 17 ++- piff/gp_interp.py | 7 +- piff/interp.py | 28 +++-- piff/knn_interp.py | 8 +- piff/mean_interp.py | 10 +- piff/model.py | 32 +++-- piff/outliers.py | 22 +++- piff/polynomial_interp.py | 7 +- piff/psf.py | 85 +++---------- piff/simplepsf.py | 21 ++-- piff/singlechip.py | 13 +- piff/star.py | 15 ++- piff/sumpsf.py | 15 ++- piff/util.py | 23 ---- piff/writers.py | 257 ++++++++++++++++++++++++++++++++++++++ 16 files changed, 389 insertions(+), 178 deletions(-) create mode 100644 piff/writers.py diff --git a/piff/basis_interp.py b/piff/basis_interp.py index bd02f8c1..b7cc5460 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -434,11 +434,10 @@ def constant(self, value=1.): out[0] = value # The constant term is always first. return out - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Write the solution to a FITS binary table. - :param fits: An open fitsio.FITS object. - :param extname: The base name of the extension. + :param writer: A writer object that encapsulates the serialization format. """ if self.q is None: raise RuntimeError("Solution not set yet. Cannot write this BasisPolynomial.") @@ -446,7 +445,7 @@ def _finish_write(self, fits, extname): dtypes = [ ('q', float, self.q.shape) ] data = np.zeros(1, dtype=dtypes) data['q'] = self.q - fits.write_table(data, extname=extname + '_solution') + writer.write_table('solution', data) def _finish_read(self, fits, extname): """Read the solution from a FITS binary table. diff --git a/piff/convolvepsf.py b/piff/convolvepsf.py index cabf1503..22e8f866 100644 --- a/piff/convolvepsf.py +++ b/piff/convolvepsf.py @@ -20,7 +20,7 @@ import galsim from .psf import PSF -from .util import write_kwargs, read_kwargs +from .util import read_kwargs from .star import Star, StarFit from .outliers import Outliers @@ -271,11 +271,10 @@ def _getRawProfile(self, star, skip=None): else: return galsim.Convolve(profiles), method - def _finish_write(self, fits, extname, logger): + def _finish_write(self, writer, logger): """Finish the writing process with any class-specific steps. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension to write to. + :param writer: A writer object that encapsulates the serialization format. :param logger: A logger object for logging debug info. """ logger = galsim.config.LoggerWrapper(logger) @@ -285,13 +284,13 @@ def _finish_write(self, fits, extname, logger): 'dof' : self.dof, 'nremoved' : self.nremoved, } - write_kwargs(fits, extname + '_chisq', chisq_dict) - logger.debug("Wrote the chisq info to extension %s",extname + '_chisq') + writer.write_struct('chisq', chisq_dict) + logger.debug("Wrote the chisq info to extension %s", writer.get_full_name('chisq')) for k, comp in enumerate(self.components): - comp._write(fits, extname + '_' + str(k), logger=logger) + comp._write(writer, str(k), logger=logger) if self.outliers: - self.outliers.write(fits, extname + '_outliers') - logger.debug("Wrote the PSF outliers to extension %s",extname + '_outliers') + self.outliers._write(writer, 'outliers') + logger.debug("Wrote the PSF outliers to extension %s", writer.get_full_name('outliers')) def _finish_read(self, fits, extname, logger): """Finish the reading process with any class-specific steps. diff --git a/piff/gp_interp.py b/piff/gp_interp.py index 8b33c86c..080ef462 100644 --- a/piff/gp_interp.py +++ b/piff/gp_interp.py @@ -292,11 +292,10 @@ def interpolateList(self, stars, logger=None): fitted_stars.append(Star(star.data, fit)) return fitted_stars - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Finish the writing process with any class-specific steps. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. """ # Note, we're only storing the training data and hyperparameters here, which means the # Cholesky decomposition will have to be re-computed when this object is read back from @@ -320,7 +319,7 @@ def _finish_write(self, fits, extname): data['ROWS'] = self.rows data['OPTIMIZER'] = self.optimizer - fits.write_table(data, extname=extname+'_kernel') + writer.write_table('kernel', data) def _finish_read(self, fits, extname): """Finish the reading process with any class-specific steps. diff --git a/piff/interp.py b/piff/interp.py index 22cde835..7747eb99 100644 --- a/piff/interp.py +++ b/piff/interp.py @@ -177,32 +177,44 @@ def interpolateList(self, stars, logger=None): def write(self, fits, extname): """Write an Interp to a FITS file. + This method exists for backwards compatibility; subclasses should + reimplement _write or _finish_write instead. + + :param fits: An open fitsio.FITS object + :param extname: The name of the extension to write the interpolator information. + """ + from .writers import FitsWriter + self._write(FitsWriter(fits, None, {}), extname) + + def _write(self, writer, name): + """Write an Interp via a Writer object. + Note: this only writes the initialization kwargs to the fits extension, not the parameters. - The base class implemenation works if the class has a self.kwargs attribute and these + The base class implementation works if the class has a self.kwargs attribute and these are all simple values (str, float, or int). However, the derived class will need to implement _finish_write to write the solution parameters to a binary table. - :param fits: An open fitsio.FITS object - :param extname: The name of the extension to write the interpolator information. + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with this interpolator in the serialized output. """ # First write the basic kwargs that works for all Interp classes interp_type = self.__class__._type_name - write_kwargs(fits, extname, dict(self.kwargs, type=interp_type)) + writer.write_struct(name, dict(self.kwargs, type=interp_type)) # Now do the class-specific steps. Typically, this will write out the solution parameters. - self._finish_write(fits, extname) + with writer.nested(name) as w: + self._finish_write(w) - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Finish the writing process with any class-specific steps. The base class implementation doesn't do anything, but this will probably always be overridden by the derived class. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. """ raise NotImplementedError("Derived classes must define the _finish_write method.") diff --git a/piff/knn_interp.py b/piff/knn_interp.py index fbbf5b5e..36969c47 100644 --- a/piff/knn_interp.py +++ b/piff/knn_interp.py @@ -156,13 +156,12 @@ def interpolateList(self, stars, logger=None): stars_fitted.append(Star(star.data, fit)) return stars_fitted - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Write the solution to a FITS binary table. Save the knn params and the locations and targets arrays - :param fits: An open fitsio.FITS object. - :param extname: The base name of the extension with the interp information. + :param writer: A writer object that encapsulates the serialization format. """ dtypes = [('LOCATIONS', self.locations.dtype, self.locations.shape), @@ -173,8 +172,7 @@ def _finish_write(self, fits, extname): data['LOCATIONS'] = self.locations data['TARGETS'] = self.targets - # write to fits - fits.write_table(data, extname=extname + '_solution') + writer.write_table('solution', data) def _finish_read(self, fits, extname): """Read the solution from a FITS binary table. diff --git a/piff/mean_interp.py b/piff/mean_interp.py index 1ba1ad0a..79cc6b3a 100644 --- a/piff/mean_interp.py +++ b/piff/mean_interp.py @@ -60,16 +60,12 @@ def interpolate(self, star, logger=None): fit = star.fit.newParams(self.mean, num=self._num) return Star(star.data, fit) - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Write the solution to a FITS binary table. - :param fits: An open fitsio.FITS object. - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. """ - cols = [ self.mean ] - dtypes = [ ('mean', float) ] - data = np.array(list(zip(*cols)), dtype=dtypes) - fits.write_table(data, extname=extname + '_solution') + writer.write_struct('solution', {'mean': self.mean}) def _finish_read(self, fits, extname): """Read the solution from a FITS binary table. diff --git a/piff/model.py b/piff/model.py index 393a4057..ba811149 100644 --- a/piff/model.py +++ b/piff/model.py @@ -16,10 +16,7 @@ .. module:: model """ -import numpy as np -import galsim - -from .util import write_kwargs, read_kwargs +from .util import read_kwargs from .star import Star @@ -157,30 +154,41 @@ def draw(self, star, copy_image=True): def write(self, fits, extname): """Write a Model to a FITS file. + This method exists for backwards compatibility; subclasses should + reimplement _write or _finish_write instead. + + :param fits: An open fitsio.FITS object + :param extname: The name of the extension to write the model information. + """ + from .writers import FitsWriter + self._write(FitsWriter(fits, None, {}), extname) + + def _write(self, writer, name): + """Write a Model via a Writer object. + Note: this only writes the initialization kwargs to the fits extension, not the parameters. The base class implemenation works if the class has a self.kwargs attribute and these are all simple values (str, float, or int) - :param fits: An open fitsio.FITS object - :param extname: The name of the extension to write the model information. + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with this model in the serialized output. """ # First write the basic kwargs that works for all Model classes model_type = self._type_name - write_kwargs(fits, extname, dict(self.kwargs, type=model_type)) - + writer.write_struct(name, dict(self.kwargs, type=model_type)) # Now do any class-specific steps. - self._finish_write(fits, extname) + with writer.nested(name) as w: + self._finish_write(w) - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Finish the writing process with any class-specific steps. The base class implementation doesn't do anything, which is often appropriate, but this hook exists in case any Model classes need to write extra information to the fits file. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. """ pass diff --git a/piff/outliers.py b/piff/outliers.py index fbb6bf93..0d677e52 100644 --- a/piff/outliers.py +++ b/piff/outliers.py @@ -22,7 +22,7 @@ import galsim from scipy.stats import chi2 -from .util import write_kwargs, read_kwargs +from .util import read_kwargs class Outliers(object): """The base class for handling outliers. @@ -97,22 +97,32 @@ def write(self, fits, extname): :param fits: An open fitsio.FITS object :param extname: The name of the extension to write the outliers information. """ + from .writers import FitsWriter + self._write(FitsWriter(fits, None, {}), extname) + + def _write(self, writer, name): + """Write an Outers via a Writer object. + + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with the Ootliers in the serialized output. + """ # First write the basic kwargs that works for all Outliers classes outliers_type = self._type_name - write_kwargs(fits, extname, dict(self.kwargs, type=outliers_type)) + writer.write_struct(name, dict(self.kwargs, type=outliers_type)) # Now do any class-specific steps. - self._finish_write(fits, extname) + with writer.nested(name) as w: + self._finish_write(w) - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Finish the writing process with any class-specific steps. The base class implementation doesn't do anything, which is often appropriate, but this hook exists in case any Outliers classes need to write extra information to the fits file. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with the Ootliers in the serialized output. """ pass diff --git a/piff/polynomial_interp.py b/piff/polynomial_interp.py index 9e3c72ae..91de4577 100644 --- a/piff/polynomial_interp.py +++ b/piff/polynomial_interp.py @@ -320,11 +320,10 @@ def model(uv,*coeffs): # self._unpack_coefficients self.coeffs = coeffs - def _finish_write(self, fits, extname): + def _finish_write(self, writer): """Write the solution to a FITS binary table. - :param fits: An open fitsio.FITS object. - :param extname: The base name of the extension + :param writer: A writer object that encapsulates the serialization format. """ if self.coeffs is None: raise RuntimeError("Coeffs not set yet. Cannot write this Polynomial.") @@ -368,7 +367,7 @@ def _finish_write(self, fits, extname): # Finally, write all of this to a FITS table. data = np.array(list(zip(*cols)), dtype=dtypes) - fits.write_table(data, extname=extname + '_solution', header=header) + writer.write_table('solution', data, metadata=header) def _finish_read(self, fits, extname): diff --git a/piff/psf.py b/piff/psf.py index fbb4ccd6..d216e343 100644 --- a/piff/psf.py +++ b/piff/psf.py @@ -19,10 +19,10 @@ import numpy as np import fitsio import galsim -import sys from .star import Star, StarData -from .util import write_kwargs, read_kwargs +from .util import read_kwargs +from .writers import Writer class PSF(object): """The base class for describing a PSF model across a field of view. @@ -708,7 +708,6 @@ def _drawStar(self, star): # their actual PSF model. raise NotImplementedError("Derived classes must define the _drawStar function") - def _getProfile(self, star): prof, method = self._getRawProfile(star) prof = prof.shift(star.fit.center) * star.fit.flux return prof, method @@ -725,31 +724,29 @@ def write(self, file_name, logger=None): logger = galsim.config.LoggerWrapper(logger) logger.warning("Writing PSF to file %s",file_name) - with fitsio.FITS(file_name,'rw',clobber=True) as f: - self._write(f, 'psf', logger) + with Writer.open(file_name) as w: + self._write(w, logger) - def _write(self, fits, extname, logger=None): + def _write(self, writer, name, logger): """This is the function that actually does the work for the write function. Composite PSF classes that need to iterate can call this multiple times as needed. - :param fits: An open fitsio.FITS object - :param extname: The name of the extension with the psf information. + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with this PSF in the serialized output. :param logger: A logger object for logging debug info. """ from . import __version__ as piff_version - if len(fits) == 1: - header = {'piff_version': piff_version} - fits.write(data=None, header=header) psf_type = self._type_name - write_kwargs(fits, extname, dict(self.kwargs, type=psf_type, piff_version=piff_version)) - logger.info("Wrote the basic PSF information to extname %s", extname) - if hasattr(self, 'stars'): - Star.write(self.stars, fits, extname=extname + '_stars') - logger.info("Wrote the PSF stars to extname %s", extname + '_stars') - if hasattr(self, 'wcs'): - self.writeWCS(fits, extname=extname + '_wcs', logger=logger) - logger.info("Wrote the PSF WCS to extname %s", extname + '_wcs') - self._finish_write(fits, extname=extname, logger=logger) + writer.write_struct(name, dict(self.kwargs, type=psf_type, piff_version=piff_version)) + with writer.nested(name) as w: + logger.info("Wrote the basic PSF information to name %s", w.get_full_name('psf')) + if hasattr(self, 'stars'): + Star.write(self.stars, w, 'stars') + logger.info("Wrote the PSF stars to name %s", w.get_full_name('stars')) + if hasattr(self, 'wcs'): + w.write_wcs_map('wcs', self.wcs, self.pointing) + logger.info("Wrote the PSF WCS to name %s", w.get_full_name('wcs')) + self._finish_write(w, logger=logger) @classmethod def read(cls, file_name, logger=None): @@ -817,54 +814,6 @@ def _read(cls, fits, extname, logger): return psf - def writeWCS(self, fits, extname, logger): - """Write the WCS information to a FITS file. - - :param fits: An open fitsio.FITS object - :param extname: The name of the extension to write to - :param logger: A logger object for logging debug info. - """ - import base64 - try: - import cPickle as pickle - except ImportError: - import pickle - logger = galsim.config.LoggerWrapper(logger) - - # Start with the chipnums - chipnums = list(self.wcs.keys()) - cols = [ chipnums ] - dtypes = [ ('chipnums', int) ] - - # GalSim WCS objects can be serialized via pickle - wcs_str = [ base64.b64encode(pickle.dumps(w)) for w in self.wcs.values() ] - max_len = np.max([ len(s) for s in wcs_str ]) - # Some GalSim WCS serializations are rather long. In particular, the Pixmappy one - # is longer than the maximum length allowed for a column in a fits table (28799). - # So split it into chunks of size 2**14 (mildly less than this maximum). - chunk_size = 2**14 - nchunks = max_len // chunk_size + 1 - cols.append( [nchunks]*len(chipnums) ) - dtypes.append( ('nchunks', int) ) - - # Update to size of chunk we actually need. - chunk_size = (max_len + nchunks - 1) // nchunks - - chunks = [ [ s[i:i+chunk_size] for i in range(0, max_len, chunk_size) ] for s in wcs_str ] - cols.extend(zip(*chunks)) - dtypes.extend( ('wcs_str_%04d'%i, bytes, chunk_size) for i in range(nchunks) ) - - if self.pointing is not None: - # Currently, there is only one pointing for all the chips, but write it out - # for each row anyway. - dtypes.extend( (('ra', float), ('dec', float)) ) - ra = [self.pointing.ra / galsim.hours] * len(chipnums) - dec = [self.pointing.dec / galsim.degrees] * len(chipnums) - cols.extend( (ra, dec) ) - - data = np.array(list(zip(*cols)), dtype=dtypes) - fits.write_table(data, extname=extname) - @classmethod def readWCS(cls, fits, extname, logger): """Read the WCS information from a FITS file. diff --git a/piff/simplepsf.py b/piff/simplepsf.py index b48a1de7..68889f73 100644 --- a/piff/simplepsf.py +++ b/piff/simplepsf.py @@ -242,11 +242,10 @@ def _drawStar(self, star): def _getRawProfile(self, star): return self.model.getProfile(star.fit.get_params(self._num)), self.model._method - def _finish_write(self, fits, extname, logger): + def _finish_write(self, writer, logger): """Finish the writing process with any class-specific steps. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension to write to. + :param writer: A writer object that encapsulates the serialization format. :param logger: A logger object for logging debug info. """ logger = galsim.config.LoggerWrapper(logger) @@ -257,15 +256,15 @@ def _finish_write(self, fits, extname, logger): 'nremoved' : self.nremoved, 'niter' : self.niter, } - write_kwargs(fits, extname + '_chisq', chisq_dict) - logger.debug("Wrote the chisq info to extension %s",extname + '_chisq') - self.model.write(fits, extname + '_model') - logger.debug("Wrote the PSF model to extension %s",extname + '_model') - self.interp.write(fits, extname + '_interp') - logger.debug("Wrote the PSF interp to extension %s",extname + '_interp') + writer.write_struct('chisq', chisq_dict) + logger.debug("Wrote the chisq info to %s", writer.get_full_name('chisq')) + self.model._write(writer, 'model') + logger.debug("Wrote the PSF model to %s", writer.get_full_name('model')) + self.interp._write(writer, 'interp') + logger.debug("Wrote the PSF interp to %s", writer.get_full_name('interp')) if self.outliers: - self.outliers.write(fits, extname + '_outliers') - logger.debug("Wrote the PSF outliers to extension %s",extname + '_outliers') + self.outliers._write(writer, 'outliers') + logger.debug("Wrote the PSF outliers to %s", writer.get_full_name('outliers')) def _finish_read(self, fits, extname, logger): """Finish the reading process with any class-specific steps. diff --git a/piff/singlechip.py b/piff/singlechip.py index f9fdff79..ae6f4a2a 100644 --- a/piff/singlechip.py +++ b/piff/singlechip.py @@ -178,14 +178,13 @@ def _getRawProfile(self, star): chipnum = star['chipnum'] return self.psf_by_chip[chipnum]._getRawProfile(star) - def _finish_write(self, fits, extname, logger): + def _finish_write(self, writer, logger): """Finish the writing process with any class-specific steps. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension to write to. + :param writer: A writer object that encapsulates the serialization format. :param logger: A logger object for logging debug info. """ - # Write the colnums to an extension. + # Write the colnums to a table. chipnums = list(self.psf_by_chip.keys()) chipnums = [c for c in chipnums if self.psf_by_chip[c] is not None] dt = make_dtype('chipnums', chipnums[0]) @@ -193,11 +192,11 @@ def _finish_write(self, fits, extname, logger): cols = [ chipnums ] dtypes = [ dt ] data = np.array(list(zip(*cols)), dtype=dtypes) - fits.write_table(data, extname=extname + '_chipnums') + writer.write_table('chipnums', data) - # Add _1, _2, etc. to the extname for the psf model of each chip. + # Append 1, 2, etc. to the name for the psf model of each chip. for chipnum in chipnums: - self.psf_by_chip[chipnum]._write(fits, extname + '_%s'%chipnum, logger) + self.psf_by_chip[chipnum]._write(writer, str(chipnum), logger) def _finish_read(self, fits, extname, logger): """Finish the reading process with any class-specific steps. diff --git a/piff/star.py b/piff/star.py index 0ceda3bd..7cc9f383 100644 --- a/piff/star.py +++ b/piff/star.py @@ -328,13 +328,24 @@ def makeTarget(cls, x=None, y=None, u=None, v=None, properties={}, wcs=None, sca return cls(data, fit) @classmethod - def write(self, stars, fits, extname): + def write(cls, stars, fits, extname): """Write a list of stars to a FITS file. :param stars: A list of stars to write :param fits: An open fitsio.FITS object :param extname: The name of the extension to write to """ + from .writers import FitsWriter + cls._write(stars, FitsWriter(fits, None, {}), extname) + + @classmethod + def _write(cls, stars, writer, name): + """Write a list of stars to a Writer object. + + :param stars: A list of stars to write + :param writer: A writer object that encapsulates the serialization format. + :param name: A name to associate with these stars in the serialized output. + """ # TODO This doesn't write everything out. Probably want image as an optional I/O. cols = [] @@ -406,7 +417,7 @@ def write(self, stars, fits, extname): cols.append( [s.data.pointing.dec / galsim.degrees for s in stars ] ) data = np.array(list(zip(*cols)), dtype=dtypes) - fits.write_table(data, extname=extname, header=header) + writer.write_table(name, data, metadata=header) @classmethod def read_coords_params(cls, fits, extname): diff --git a/piff/sumpsf.py b/piff/sumpsf.py index 75c9d865..c61486b2 100644 --- a/piff/sumpsf.py +++ b/piff/sumpsf.py @@ -288,11 +288,10 @@ def _getRawProfile(self, star): # Add them up. return galsim.Sum(profiles), method - def _finish_write(self, fits, extname, logger): + def _finish_write(self, writer, logger): """Finish the writing process with any class-specific steps. - :param fits: An open fitsio.FITS object - :param extname: The base name of the extension to write to. + :param writer: A writer object that encapsulates the serialization format. :param logger: A logger object for logging debug info. """ logger = galsim.config.LoggerWrapper(logger) @@ -302,13 +301,13 @@ def _finish_write(self, fits, extname, logger): 'dof' : self.dof, 'nremoved' : self.nremoved, } - write_kwargs(fits, extname + '_chisq', chisq_dict) - logger.debug("Wrote the chisq info to extension %s",extname + '_chisq') + writer.write_struct('chisq', chisq_dict) + logger.debug("Wrote the chisq info to %s", writer.get_full_name('chisq')) for k, comp in enumerate(self.components): - comp._write(fits, extname + '_' + str(k), logger=logger) + comp._write(writer, str(k), logger=logger) if self.outliers: - self.outliers.write(fits, extname + '_outliers') - logger.debug("Wrote the PSF outliers to extension %s",extname + '_outliers') + self.outliers._write(writer, 'outliers') + logger.debug("Wrote the PSF outliers to %s", writer.get_full_name('outliers')) def _finish_read(self, fits, extname, logger): """Finish the reading process with any class-specific steps. diff --git a/piff/util.py b/piff/util.py index 7503e9f5..8fd03370 100644 --- a/piff/util.py +++ b/piff/util.py @@ -101,29 +101,6 @@ def adjust_value(value, dtype): # For other numpy arrays, we can use astype instead. return np.array(value).astype(t) -def write_kwargs(fits, extname, kwargs): - """A helper function for writing a single row table into a fits file with the values - and column names given by a kwargs dict. - - :param fits: An open fitsio.FITS instance - :param extname: The extension to write to - :param kwargs: A kwargs dict to be written as a FITS binary table. - """ - from . import __version__ as piff_version - cols = [] - dtypes = [] - for key, value in kwargs.items(): - # Don't add values that are None to the table. - if value is None: - continue - dt = make_dtype(key, value) - value = adjust_value(value,dt) - cols.append([value]) - dtypes.append(dt) - data = np.array(list(zip(*cols)), dtype=dtypes) - header = {'piff_version': piff_version} - fits.write_table(data, extname=extname, header=header) - def read_kwargs(fits, extname): """A helper function for reading a single row table from a fits file returning the values and column names as a kwargs dict. diff --git a/piff/writers.py b/piff/writers.py new file mode 100644 index 00000000..8405777b --- /dev/null +++ b/piff/writers.py @@ -0,0 +1,257 @@ +# Copyright (c) 2024 by Mike Jarvis and the other collaborators on GitHub at +# https://github.com/rmjarvis/Piff All rights reserved. +# +# Piff is free software: Redistribution and use in source and binary forms +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the disclaimer given in the documentation +# and/or other materials provided with the distribution. + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager, AbstractContextManager + +import os +import fitsio +import galsim +import numpy as np + +from .util import make_dtype, adjust_value + + +class Writer: + @contextmanager + @staticmethod + def open(file_name: str): + _, ext = os.path.splitext(file_name) + if ext == ".fits": + with FitsWriter._open(file_name) as writer: + yield writer + return + else: + raise NotImplementedError("No writer for extension {ext!r}.") + + def write_struct( + self, name: str, struct: dict, metadata: dict | None = None + ) -> None: + raise NotImplementedError() + + def write_table( + self, name: str, table: np.ndarray, metadata: dict | None = None + ) -> None: + raise NotImplementedError() + + def write_array( + self, name: str, array: np.ndarray, metadata: dict | None = None + ) -> None: + raise NotImplementedError() + + def write_wcs_map( + self, + name: str, + wcs_map: dict[int, galsim.BaseWCS], + pointing: galsim.CelestialCoord | None, + ) -> None: + raise NotImplementedError() + + def nested(self, name: str) -> AbstractContextManager[Writer]: + raise NotImplementedError() + + def get_full_name(self, name: str) -> str: + raise NotImplementedError() + + +class FitsWriter(Writer): + def __init__(self, fits: fitsio.FITS, base_name: str | None, header: dict): + self._fits = fits + self._base_name = base_name + self._header = header + + @contextmanager + @classmethod + def _open(cls, file_name: str): + from . import __version__ as piff_version + + header = {"piff_version": piff_version} + with fitsio.FITS(file_name, "rw", clobber=True) as f: + if len(f) == 1: + f.write(data=None, header=header) + yield cls(f, base_name=None, header=header.copy()) + + def write_struct( + self, name: str, struct: dict, metadata: dict | None = None + ) -> None: + cols = [] + dtypes = [] + for key, value in struct.items(): + # Don't add values that are None to the table. + if value is None: + continue + dt = make_dtype(key, value) + value = adjust_value(value, dt) + cols.append([value]) + dtypes.append(dt) + table = np.array(list(zip(*cols)), dtype=dtypes) + return self.write_table(name, table, metadata=metadata) + + def write_table( + self, name: str, array: np.ndarray, metadata: dict | None = None + ) -> None: + if metadata: + header = self._header.copy() + header.update(metadata) + else: + header = self._header + self._fits.write_table(array, extname=self.get_full_name(name), header=header) + + def write_array( + self, name: str, array: np.ndarray, metadata: dict | None = None + ) -> None: + if metadata: + header = self._header.copy() + header.update(metadata) + else: + header = self._header + self._fits.write(array, extname=self.get_full_name(name), header=self._header) + + def write_wcs_map( + self, + name: str, + wcs_map: dict[int, galsim.BaseWCS], + pointing: galsim.CelestialCoord | None, + ) -> None: + import base64 + + try: + import cPickle as pickle + except ImportError: + import pickle + # Start with the chipnums + chipnums = list(wcs_map.keys()) + cols = [chipnums] + dtypes = [("chipnums", int)] + + # GalSim WCS objects can be serialized via pickle + wcs_str = [base64.b64encode(pickle.dumps(w)) for w in wcs_map.values()] + max_len = np.max([len(s) for s in wcs_str]) + # Some GalSim WCS serializations are rather long. In particular, the Pixmappy one + # is longer than the maximum length allowed for a column in a fits table (28799). + # So split it into chunks of size 2**14 (mildly less than this maximum). + chunk_size = 2**14 + nchunks = max_len // chunk_size + 1 + cols.append([nchunks] * len(chipnums)) + dtypes.append(("nchunks", int)) + + # Update to size of chunk we actually need. + chunk_size = (max_len + nchunks - 1) // nchunks + + chunks = [ + [s[i : i + chunk_size] for i in range(0, max_len, chunk_size)] + for s in wcs_str + ] + cols.extend(zip(*chunks)) + dtypes.extend(("wcs_str_%04d" % i, bytes, chunk_size) for i in range(nchunks)) + + if pointing is not None: + # Currently, there is only one pointing for all the chips, but write it out + # for each row anyway. + dtypes.extend((("ra", float), ("dec", float))) + ra = [pointing.ra / galsim.hours] * len(chipnums) + dec = [pointing.dec / galsim.degrees] * len(chipnums) + cols.extend((ra, dec)) + + data = np.array(list(zip(*cols)), dtype=dtypes) + self.write_table(data, name) + + @contextmanager + def nested(self, name: str) -> Iterator[FitsWriter]: + yield FitsWriter( + self._fits, base_name=self.get_full_name(name), header=self._header + ) + + def get_full_name(self, name: str) -> str: + return name if self._base_name is None else f"{self._base_name}_{name}" + + +class DictWriter(Writer): + def __init__(self, path: str, data: dict): + self._path = path + self._data = data + + def write_struct(self, name: str, struct: dict) -> None: + if name in self._data: + if not isinstance(self._data[name], dict): + raise AssertionError( + f"Key conflict at {self.get_full_name(name)} is written by multiple sub-objects.." + " This is a logic bug in serialization." + ) + if struct.keys().isdisjoint(self._data[name].keys()): + raise AssertionError( + f"Key conflict at {self.get_full_name(name)}: {struct.keys() & self._data[name].keys()}" + " written by different sub-objects. This is a logic bug in serialization." + ) + self._data[name].update(struct) + else: + self._data[name] = struct.copy() + + def write_table( + self, name: str, table: np.ndarray, metadata: dict | None = None + ) -> None: + # Default implementation just dumps the numpy structured array as-is + # into the output nested dictionary (with a little nesting to make room + # for metadata and make reading more type-safe), but this makes writing + # directly to JSON or YAML impossible. Subclasses can override this + # method to do something more clever. + struct = {"type": "table", "data": table} + if metadata: + struct.update(metadata) + self.write_struct(name, struct) + + def write_array( + self, name: str, array: np.ndarray, metadata: dict | None = None + ) -> None: + # See comment on write_table. + struct = {"type": "array", "data": array} + if metadata: + struct.update(metadata) + self.write_struct(name, struct) + + def write_wcs_map( + self, + name: str, + wcs_map: dict[int, galsim.BaseWCS], + pointing: galsim.CelestialCoord | None, + ) -> None: + import base64 + + try: + import cPickle as pickle + except ImportError: + import pickle + struct = { + "type": "wcsmap", + "chips": { + # Serialize galsim WCSs as base64-encoded pickle blobs. + str(k): base64.b64encode(pickle.dumps(v)) for k, v in wcs_map.items() + }, + } + if pointing is not None: + struct["pointing"] = { + "ra": pointing.ra / galsim.hours, + "dec": pointing.dec / galsim.degrees, + } + self.write_struct(name, struct) + + @contextmanager + def nested(self, name: str) -> Iterator[Writer]: + nested_data = self._data.setdefault(name, {}) + yield DictWriter(self.get_full_name(name), nested_data) + + def get_full_name(self, name: str) -> str: + return f"{self._path}/{name}"