diff --git a/.gitignore b/.gitignore index 350b3c7a..0fd3cf00 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,6 @@ target/ # local sandbox local/ + +# local flake8 running +.flake8 diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index 9bc1599e..255cd858 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -1,9 +1,7 @@ import os import sys import re -import math import logging -import yaml import numpy as np import numpy.ma as ma import healpy @@ -11,19 +9,18 @@ import pyarrow as pa import pyarrow.parquet as pq from multiprocessing import Process, Pipe -from astropy.coordinates import SkyCoord import sqlite3 -from .utils.common_utils import print_date from .utils.sed_tools import TophatSedFactory, get_star_sed_path from .utils.config_utils import create_config, assemble_SED_models -from .utils.config_utils import assemble_MW_extinction, assemble_cosmology, assemble_object_types, assemble_provenance, write_yaml -from .utils.parquet_schema_utils import make_galaxy_schema, make_galaxy_flux_schema, make_star_flux_schema, make_pointsource_schema +from .utils.config_utils import assemble_MW_extinction, assemble_cosmology +from .utils.config_utils import assemble_object_types, assemble_provenance +from .utils.config_utils import write_yaml +from .utils.parquet_schema_utils import make_galaxy_schema +from .utils.parquet_schema_utils import make_galaxy_flux_schema +from .utils.parquet_schema_utils import make_star_flux_schema +from .utils.parquet_schema_utils import make_pointsource_schema from .utils.creator_utils import make_MW_extinction_av, make_MW_extinction_rv from .objects.base_object import LSST_BANDS -from .objects.base_object import ObjectCollection - -# from dm stack -from dustmaps.sfd import SFDQuery """ Code to create a sky catalog for particular object types @@ -33,6 +30,7 @@ _MW_rv_constant = 3.1 + def _generate_sed_path(ids, subdir, cmp): ''' Generate paths (e.g. relative to SIMS_SED_LIBRARY_DIR) for galaxy component @@ -51,6 +49,7 @@ def _generate_sed_path(ids, subdir, cmp): r = [f'{subdir}/{cmp}_{id}.txt' for id in ids] return r + def _get_tophat_info(columns): ''' Parameters @@ -83,6 +82,7 @@ def _bin_start_key(start_width): # Moving on to tophat_fetch def _sed_bulge_key(s): return int(re.match(tophat_bulge_re, s)['start']) + def _sed_disk_key(s): return int(re.match(tophat_disk_re, s)['start']) @@ -92,8 +92,10 @@ def _sed_disk_key(s): return sed_bins, sed_bulge_names, sed_disk_names + _nside_allowed = 2**np.arange(15) + def _find_subpixels(pixel, subpixel_nside, pixel_nside=32, nest=False): ''' Return list of pixels of specified nside inside a given pixel @@ -102,12 +104,12 @@ def _find_subpixels(pixel, subpixel_nside, pixel_nside=32, nest=False): pixel int the id of the input pixel subpixel_nside int nside for subpixels pixel_nside int nside of original pixel (default=32) - nest boolean True if pixel ordering for original pixel is nested - (default = False) + nest boolean True if pixel ordering for original pixel is + nested (default = False) Returns ------- - List of subpixel ids (nested ordering iff original was). If subpixel resolution - is no better than original, just return original pixel id + List of subpixel ids (nested ordering iff original was). If subpixel + resolution is no better than original, just return original pixel id ''' if pixel_nside not in _nside_allowed: raise ValueError(f'Disallowed pixel nside value {pixel_nside}') @@ -135,11 +137,13 @@ def _next_level(pixel): else: return [healpy.nest2ring(subpixel_nside, p) for p in pixels] + def _generate_subpixel_masks(ra, dec, subpixels, nside=32): ''' - Given ra, dec values for objects within a particular pixel and a list of its - subpixels for some greater value of nside, return dict with subpixel ids as - keys and values a mask which masks off all values except those belonging to subpixel + Given ra, dec values for objects within a particular pixel and a list of + its subpixels for some greater value of nside, return dict with subpixel + ids as keys and values a mask which masks off all values except those + belonging to subpixel Parameters ---------- @@ -175,9 +179,9 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, l_bnd, u_bnd): ''' out_dict = {} - o_list = galaxy_collection[l_bnd : u_bnd] + o_list = galaxy_collection[l_bnd: u_bnd] out_dict['galaxy_id'] = [o.get_native_attribute('galaxy_id') for o in o_list] - all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] + all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) for i, band in enumerate(LSST_BANDS): v = all_fluxes_transpose.__next__() @@ -188,6 +192,7 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, l_bnd, u_bnd): else: return out_dict + class CatalogCreator: def __init__(self, parts, area_partition=None, skycatalog_root=None, catalog_dir='.', galaxy_truth=None, @@ -332,7 +337,7 @@ def _make_tophat_columns(self, dat, names, cmp): ''' sed_vals = (np.array([dat[k] for k in names]).T).tolist() dat['sed_val_' + cmp] = sed_vals - dat[cmp + '_magnorm'] = [self._obs_sed_factory.magnorm(s, z) for (s, z)\ + dat[cmp + '_magnorm'] = [self._obs_sed_factory.magnorm(s, z) for (s, z) in zip(sed_vals, dat['redshiftHubble'])] for k in names: del(dat[k]) @@ -361,7 +366,7 @@ def create(self, catalog_type): if not self._main_only: self.create_pointsource_flux_catalog() else: - raise NotImplemented(f'CatalogCreator.create: unsupported catalog type {catalog_type}') + raise NotImplementedError(f'CatalogCreator.create: unsupported catalog type {catalog_type}') def create_galaxy_catalog(self): """ @@ -382,9 +387,8 @@ def create_galaxy_catalog(self): # Save cosmology in case we need to write parameters out later self._cosmology = gal_cat.cosmology - arrow_schema = make_galaxy_schema(self._logname, - self._sed_subdir, - self._knots) + arrow_schema = make_galaxy_schema(self._logname, self._sed_subdir, + self._knots) for p in self._parts: self._logger.info(f'Starting on pixel {p}') @@ -418,8 +422,8 @@ def _write_subpixel(self, dat=None, output_path=None, arrow_schema=None, for val in dat.values(): dlen = len(val) break - if dlen == 0: return - + if dlen == 0: + return last_row_ix = dlen - 1 u_bnd = min(stride, dlen) l_bnd = 0 @@ -427,9 +431,9 @@ def _write_subpixel(self, dat=None, output_path=None, arrow_schema=None, writer = None while u_bnd > l_bnd: - out_dict = {k : dat[k][l_bnd : u_bnd] for k in dat if k not in to_rename} + out_dict = {k: dat[k][l_bnd: u_bnd] for k in dat if k not in to_rename} for k in to_rename: - out_dict[to_rename[k]] = dat[k][l_bnd : u_bnd] + out_dict[to_rename[k]] = dat[k][l_bnd: u_bnd] out_df = pd.DataFrame.from_dict(out_dict) out_table = pa.Table.from_pandas(out_df, schema=arrow_schema) if not writer: @@ -511,7 +515,7 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema): self._obs_sed_factory = TophatSedFactory(self._sed_bins, assemble_cosmology(self._cosmology)) - #Fetch the data + # Fetch the data to_fetch = non_sed + sed_bulge_names + sed_disk_names # df is not a dataframe! It's just a dict @@ -527,8 +531,8 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema): self._logger.debug('Made extinction') # Some columns need to be renamed - to_rename = {'redshiftHubble' : 'redshift_hubble', - 'peculiarVelocity' : 'peculiar_velocity'} + to_rename = {'redshiftHubble': 'redshift_hubble', + 'peculiarVelocity': 'peculiar_velocity'} if self._dc2: to_rename['ellipticity_1_disk_true_dc2'] = 'ellipticity_1_disk_true' to_rename['ellipticity_2_disk_true_dc2'] = 'ellipticity_2_disk_true' @@ -547,19 +551,22 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema): # adjust disk sed; create knots sed sed_knot_names = [i.replace('disk', 'knots') for i in sed_disk_names] eps = np.finfo(np.float32).eps - mag_mask = np.where(np.array(df['mag_i_lsst']) > self._knots_mag_cut,0, 1) + mag_mask = np.where(np.array(df['mag_i_lsst']) > self._knots_mag_cut, 0, 1) self._logger.debug(f'Count of mags <= cut (so adjustment performed: {np.count_nonzero(mag_mask)}') for d_name, k_name in zip(sed_disk_names, sed_knot_names): - df[k_name] = mag_mask * np.clip(df['knots_flux_ratio'], None, 1-eps) * df[d_name] + df[k_name] = mag_mask * np.clip(df['knots_flux_ratio'], + None, 1-eps) * df[d_name] df[d_name] = np.where(np.array(df['mag_i_lsst']) > self._knots_mag_cut, 1, - np.clip(1 - df['knots_flux_ratio'], eps, None)) * df[d_name] + np.clip(1 - df['knots_flux_ratio'], + eps, None)) * df[d_name] if len(self._out_pixels) > 1: - subpixel_masks = _generate_subpixel_masks(df['ra'], df['dec'], self._out_pixels, nside=self._galaxy_nside) + subpixel_masks = _generate_subpixel_masks(df['ra'], df['dec'], + self._out_pixels, + nside=self._galaxy_nside) else: - subpixel_masks = {pixel : None} - + subpixel_masks = {pixel: None} for p, val in subpixel_masks.items(): output_path = os.path.join(self._output_dir, f'galaxy_{p}.parquet') @@ -574,10 +581,14 @@ def create_galaxy_pixel(self, pixel, gal_cat, arrow_schema): for k in df: compressed[k] = ma.array(df[k], mask=val).compressed() - compressed = self._make_tophat_columns(compressed, sed_disk_names, 'disk') - compressed = self._make_tophat_columns(compressed, sed_bulge_names, 'bulge') + compressed = self._make_tophat_columns(compressed, + sed_disk_names, 'disk') + compressed = self._make_tophat_columns(compressed, + sed_bulge_names, 'bulge') if self._knots: - compressed = self._make_tophat_columns(compressed, sed_knot_names, 'knots') + compressed = self._make_tophat_columns(compressed, + sed_knot_names, + 'knots') self._write_subpixel(dat=compressed, output_path=output_path, arrow_schema=arrow_schema, @@ -609,7 +620,7 @@ def create_galaxy_flux_catalog(self, config_file=None): None ''' - from .skyCatalogs import open_catalog, SkyCatalog + from .skyCatalogs import open_catalog self._gal_flux_schema = make_galaxy_flux_schema(self._logname) @@ -632,7 +643,6 @@ def create_galaxy_flux_catalog(self, config_file=None): self._create_galaxy_flux_pixel(p) self._logger.info(f'Completed pixel {p}') - def _create_galaxy_flux_pixel(self, pixel): ''' Create a parquet file for a single healpix pixel containing only @@ -668,18 +678,18 @@ def _create_galaxy_flux_pixel(self, pixel): # prefetch everything we need. for att in ['galaxy_id', 'shear_1', 'shear_2', 'convergence', 'redshift_hubble', 'MW_av', 'MW_rv', 'sed_val_bulge', - 'sed_val_disk', 'sed_val_knots'] : - v = object_coll.get_native_attribute(att) + 'sed_val_disk', 'sed_val_knots']: + _ = object_coll.get_native_attribute(att) l_bnd = 0 u_bnd = len(object_coll) rg_written = 0 self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') - out_dict = {'galaxy_id': [], 'lsst_flux_u' : [], - 'lsst_flux_g' : [], 'lsst_flux_r' : [], - 'lsst_flux_i' : [], 'lsst_flux_z' : [], - 'lsst_flux_y' : []} + out_dict = {'galaxy_id': [], 'lsst_flux_u': [], + 'lsst_flux_g': [], 'lsst_flux_r': [], + 'lsst_flux_i': [], 'lsst_flux_z': [], + 'lsst_flux_y': []} n_parallel = self._flux_parallel @@ -687,13 +697,13 @@ def _create_galaxy_flux_pixel(self, pixel): n_per = u_bnd - l_bnd else: n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) - l = l_bnd + lb = l_bnd u = min(l_bnd + n_per, u_bnd) readers = [] if n_parallel == 1: out_dict = _do_galaxy_flux_chunk(None, _galaxy_collection, - l, u) + lb, u) else: # Expect to be able to do about 1500/minute/process tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion @@ -706,27 +716,27 @@ def _create_galaxy_flux_pixel(self, pixel): # For debugging call directly proc = Process(target=_do_galaxy_flux_chunk, name=f'proc_{i}', - args=(conn_wrt, _galaxy_collection,l, u)) + args=(conn_wrt, _galaxy_collection, lb, u)) proc.start() p_list.append(proc) - l = u - u = min(l + n_per, u_bnd) + lb = u + u = min(lb + n_per, u_bnd) - self._logger.debug('Processes started') # outside for loop + self._logger.debug('Processes started') for i in range(n_parallel): ready = readers[i].poll(tm) if not ready: self._logger.error(f'Process {i} timed out after {tm} sec') sys.exit(1) - dat = readers[i].recv() # lines up with if + dat = readers[i].recv() for k in ['galaxy_id', 'lsst_flux_u', 'lsst_flux_g', 'lsst_flux_r', 'lsst_flux_i', 'lsst_flux_z', 'lsst_flux_y']: out_dict[k] += dat[k] - for p in p_list: # indent same as "for i in range(.." + for p in p_list: p.join() - out_df = pd.DataFrame.from_dict(out_dict) # outdent from for + out_df = pd.DataFrame.from_dict(out_dict) out_table = pa.Table.from_pandas(out_df, schema=self._gal_flux_schema) @@ -734,7 +744,7 @@ def _create_galaxy_flux_pixel(self, pixel): writer = pq.ParquetWriter(output_path, self._gal_flux_schema) writer.write_table(out_table) - rg_written +=1 + rg_written += 1 writer.close() self._logger.debug(f'# row groups written to flux file: {rg_written}') @@ -766,7 +776,7 @@ def create_pointsource_catalog(self): self._logger.debug(f'Completed pixel {p}') def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, - sn_cat=None): + sn_cat=None): if not star_cat and not sn_cat: self._logger.info('No point source inputs specified') return @@ -785,7 +795,8 @@ def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, if star_cat: # Get data for this pixel - cols = ','.join(['format("%s",simobjid) as id', 'ra', 'decl as dec', + cols = ','.join(['format("%s",simobjid) as id', 'ra', + 'decl as dec', 'magNorm as magnorm', 'mura', 'mudecl as mudec', 'radialVelocity as radial_velocity', 'parallax', 'sedFilename as sed_filepath', 'ebv']) @@ -818,7 +829,7 @@ def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, 'sndec_in as dec', 'galaxy_id as host_galaxy_id']) params = ','.join(['z_in as z', 't0_in as t0, x0_in as x0', - 'x1_in as x1', 'c_in as c']) + 'x1_in as x1', 'c_in as c']) q1 = f'select {cols} from sne_params where hpid={pixel} ' q2 = f'select {params} from sne_params where hpid={pixel} ' @@ -833,7 +844,7 @@ def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, sn_df['MW_av'] = make_MW_extinction_av(sn_df['ra'], sn_df['dec']) # Add fillers for columns not relevant for sn - sn_df['sed_filepath'] = np.full((nobj),'') + sn_df['sed_filepath'] = np.full((nobj), '') sn_df['magnorm'] = np.full((nobj,), None) sn_df['mura'] = np.full((nobj,), None) sn_df['mudec'] = np.full((nobj,), None) @@ -869,7 +880,7 @@ def create_pointsource_flux_catalog(self, config_file=None): None ''' - from .skyCatalogs import open_catalog, SkyCatalog + from .skyCatalogs import open_catalog self._ps_flux_schema = make_star_flux_schema(self._logname) if not config_file: @@ -924,11 +935,11 @@ def _create_pointsource_flux_pixel(self, pixel): l_bnd = 0 u_bnd = last_row_ix + 1 - o_list = object_list[l_bnd : u_bnd] + o_list = object_list[l_bnd: u_bnd] self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') out_dict = {} out_dict['id'] = [o.get_native_attribute('id') for o in o_list] - all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] + all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) for i, band in enumerate(LSST_BANDS): self._logger.debug(f'Band {band} is number {i}') @@ -945,7 +956,7 @@ def _create_pointsource_flux_pixel(self, pixel): writer.write_table(out_table) writer.close() - ##self._logger.debug(f'# row groups written to flux file: {rg_written}') + # self._logger.debug(f'#row groups written to flux file: {rg_written}') if self._provenance == 'yaml': self.write_provenance_file(output_path) @@ -971,14 +982,14 @@ def write_config(self, overwrite=False, path_only=False): self._config_path = self._output_dir if path_only: - return os.path.join(self._config_path, self._catalog_name + '.yaml') - + return os.path.join(self._config_path, + self._catalog_name + '.yaml') config = create_config(self._catalog_name, self._logname) if self._global_partition is not None: config.add_key('area_partition', self._area_partition) config.add_key('skycatalog_root', self._skycatalog_root) - config.add_key('catalog_dir' , self._catalog_dir) + config.add_key('catalog_dir', self._catalog_dir) config.add_key('SED_models', assemble_SED_models(self._sed_bins)) @@ -991,7 +1002,7 @@ def write_config(self, overwrite=False, path_only=False): config.add_key('galaxy_magnitude_cut', self._mag_cut) config.add_key('knots_magnitude_cut', self._knots_mag_cut) - inputs = {'galaxy_truth' : self._galaxy_truth} + inputs = {'galaxy_truth': self._galaxy_truth} if self._sn_truth: inputs['sn_truth'] = self._sn_truth if self._star_truth: diff --git a/skycatalogs/objects/base_object.py b/skycatalogs/objects/base_object.py index 885a96aa..8fcafa7d 100644 --- a/skycatalogs/objects/base_object.py +++ b/skycatalogs/objects/base_object.py @@ -1,15 +1,9 @@ from collections.abc import Sequence, Iterable from collections import namedtuple import os -import gzip import itertools import numpy as np -import astropy.units as u -import warnings -from dust_extinction.parameter_averages import F19 import galsim -import logging -from galsim.errors import GalSimRangeError from skycatalogs.utils.translate_utils import form_object_string from skycatalogs.utils.config_utils import Config @@ -26,6 +20,8 @@ LSST_BANDS = ('ugrizy') # global for easy access for code run within mp + + def load_lsst_bandpasses(): ''' Read in lsst bandpasses from standard place, trim, and store in global dict @@ -40,17 +36,17 @@ def load_lsst_bandpasses(): if os.path.exists(bp_dir): BaseObject._bp_path = bp_dir - #logger.info(f'Using rubin sim dir {rubin_sim_dir}') + # logger.info(f'Using rubin sim dir {rubin_sim_dir}') else: bp_dir = os.path.join(os.getenv('HOME'), 'rubin_sim_data', - 'throughputs', 'baseline') - #logger.info(f'Using rubin sim dir rubin_sim_data under HOME') + 'throughputs', 'baseline') + # logger.info(f'Using rubin sim dir rubin_sim_data under HOME') if os.path.exists(bp_dir): BaseObject._bp_path = bp_dir else: - #logger.info('Using galsim built-in bandpasses') + # logger.info('Using galsim built-in bandpasses') bp_dir = None - fname_fmt = 'LSST_{band}.dat' + for band in LSST_BANDS: if bp_dir: bp_full_path = os.path.join(bp_dir, f'total_{band}.dat') @@ -60,22 +56,25 @@ def load_lsst_bandpasses(): # Mirror behavior in imsim.RubinBandpass: # https://github.com/LSSTDESC/imSim/blob/main/imsim/bandpass.py#L9 - # Trim the edges to avoid 1.e-4 values out to very high and low wavelengths. + # Trim the edges to avoid 1.e-4 values out to very high and low + # wavelengths. bp = bp.truncate(relative_throughput=1.e-3) - # Remove wavelength values selectively for improved speed but preserve flux integrals. + # Remove wavelength values selectively for improved speed but + # preserve flux integrals. bp = bp.thin() bp = bp.withZeropoint('AB') lsst_bandpasses[band] = bp return lsst_bandpasses + class BaseObject(object): ''' Abstract base class for static (in position coordinates) objects. Likely need a variant for SSO. ''' - _bp500 = galsim.Bandpass(galsim.LookupTable([499, 500, 501],[0, 1, 0]), + _bp500 = galsim.Bandpass(galsim.LookupTable([499, 500, 501], [0, 1, 0]), wave_type='nm').withZeropoint('AB') def __init__(self, ra, dec, id, object_type, belongs_to, belongs_index, @@ -103,7 +102,6 @@ def __init__(self, ra, dec, id, object_type, belongs_to, belongs_index, # All objects also include redshift information. Also MW extinction, # but extinction is by subcomponent for galaxies - @property def ra(self): return self._ra @@ -122,7 +120,8 @@ def object_type(self): @property def redshift(self): - if self._redshift: return self._redshift + if self._redshift: + return self._redshift if self._belongs_to: self._redshift = self.get_native_attribute('redshift') return self._redshift @@ -173,7 +172,7 @@ def get_sed_metadata(self, **kwargs): ''' raise NotImplementedError - def get_instcat_entry(self, band = 'r', component=None): + def get_instcat_entry(self, band='r', component=None): ''' Return the string corresponding to instance catalog line Parameters: @@ -204,7 +203,6 @@ def _get_sed(self, component=None, resolution=None, mjd=None): raise NotImplementedError('Must be implemented by BaseObject subclass if needed') - def write_sed(self, sed_file_path, component=None, resolution=None, mjd=None): sed, _ = self._get_sed(component=component, resolution=None, mjd=None) @@ -284,11 +282,11 @@ def get_total_observer_sed(self, mjd=None): components. """ sed = None - for sed_component in self.get_observer_sed_components(mjd=mjd).values(): + for sed_cmp in self.get_observer_sed_components(mjd=mjd).values(): if sed is None: - sed = sed_component + sed = sed_cmp else: - sed += sed_component + sed += sed_cmp return sed @@ -318,7 +316,7 @@ def get_fluxes(self, bandpasses, mjd=None): return [sed.calculateFlux(b) for b in bandpasses] def get_LSST_flux(self, band, sed=None, cache=True, mjd=None): - if not band in LSST_BANDS: + if band not in LSST_BANDS: return None att = f'lsst_flux_{band}' @@ -463,7 +461,8 @@ def get_native_attribute(self, attribute_name): exists. ''' val = getattr(self, attribute_name, None) - if val is not None: return val + if val is not None: + return val for r in self._rdrs: if attribute_name in r.columns: @@ -510,7 +509,7 @@ def get_native_attributes_iterator(self, attribute_names): ------- iterator which returns df for a chunk of values of the attributes ''' - pass # for now + pass # for now # implement Sequence methods def __contains__(self, obj): @@ -519,7 +518,7 @@ def __contains__(self, obj): ---------- obj can be an (object id) or of type BaseObject ''' - if type(obj) == type(10): + if isinstance(obj, int): id = obj else: if isinstance(obj, BaseObject): @@ -531,8 +530,8 @@ def __contains__(self, obj): def __len__(self): return len(self._id) - #def __iter__(self): Standard impl based on __getitem__ should be ok - #def __reversed__(self): Default implementation ok + # def __iter__(self): Standard impl based on __getitem__ should be ok + # def __reversed__(self): Default implementation ok def __getitem__(self, key): ''' @@ -556,7 +555,7 @@ def __getitem__(self, key): elif type(key) == slice: if key.start is None: key.start = 0 - ixdata = [i for i in range(min(key.stop,len(self._ra)))] + ixdata = [i for i in range(min(key.stop, len(self._ra)))] ixes = itertools.islice(ixdata, key.start, key.stop, key.step) return [self._object_class(self._ra[i], self._dec[i], self._id[i], object_type, self, i) @@ -568,7 +567,6 @@ def __getitem__(self, key): object_type, self, i) for i in key[0]] - def get_partition_id(self): return self._partition_id @@ -576,7 +574,8 @@ def count(self, obj): ''' returns # of occurrences of obj. It can only be 0 or 1 ''' - if self.__contains__(obj): return 1 + if self.__contains__(obj): + return 1 return 0 def index(self, obj): @@ -585,6 +584,7 @@ def index(self, obj): ''' return self._id.index(obj.id) + LocatedCollection = namedtuple('LocatedCollection', ['collection', 'first_index', 'upper_bound']) ''' @@ -595,6 +595,7 @@ def index(self, obj): so upper_bound for one collection = first_index in the next ''' + class ObjectList(Sequence): ''' Keep track of a list of ObjectCollection objects, but from user @@ -642,13 +643,13 @@ def get_collections(self): return constituent ObjectCollection objects in a list ''' collections = [] - for e in self._located: + for e in self._located: collections.append(e.collection) return collections - # implement Sequence methods + def __contains__(self, obj): ''' Parameters @@ -665,8 +666,8 @@ def __contains__(self, obj): def __len__(self): return self._total_len - #def __iter__(self): Standard impl based on __getitem__ should be ok?? - #def __reversed__(self): Default implementation ok?? + # def __iter__(self): Standard impl based on __getitem__ should be ok?? + # def __reversed__(self): Default implementation ok?? def __getitem__(self, key): ''' @@ -709,7 +710,7 @@ def __getitem__(self, key): sub = [elem - e.first_index for elem in key_list if elem >= e.first_index and elem < e.upper_bound] to_return += e.collection[(sub,)] - start_ix +=len(sub) + start_ix += len(sub) if start_ix >= len(key_list): break start = key_list[start_ix] diff --git a/skycatalogs/objects/gaia_object.py b/skycatalogs/objects/gaia_object.py index 0aae78cd..de56a50e 100644 --- a/skycatalogs/objects/gaia_object.py +++ b/skycatalogs/objects/gaia_object.py @@ -1,8 +1,8 @@ import os -import sys import warnings from functools import wraps import itertools +from collections.abc import Iterable from pathlib import PurePath import numpy as np import erfa @@ -60,6 +60,7 @@ class GaiaObject(BaseObject): _stellar_temperature = _TEMP_FUNC _gaia_bp_bandpass = _GAIA_BP _wavelengths = np.arange(250, 1250, 5, dtype=float) + def __init__(self, obj_pars, parent_collection, index): """ Parameters @@ -92,9 +93,9 @@ def __init__(self, obj_pars, parent_collection, index): # Convert from flux units of nJy to AB mag for the bp passband, # which we will use to normalize the SED. self.bp_mag = -2.5*np.log10(bp_flux*1e-9) + 8.90 - except galsim.errors.GalSimRangeError as ex: + except galsim.errors.GalSimRangeError: self.stellar_temp = None - except RuntimeError as rex: + except RuntimeError: self.stellar_temp = None def blambda(self, wl): @@ -126,16 +127,21 @@ def get_gsobject_components(self, gsparams=None, rng=None): def set_use_lut(self, use_lut): self.use_lut = use_lut + class GaiaCollection(ObjectCollection): # Class methods _gaia_config = None - def set_config(config): + + @classmethod + def set_config(cls, config): GaiaCollection._gaia_config = config - def get_config(): + @classmethod + def get_config(cls): return GaiaCollection._gaia_config @ignore_erfa_warnings + @staticmethod def load_collection(region, skycatalog, mjd=None): if isinstance(region, Disk): ra = lsst.geom.Angle(region.ra, lsst.geom.degrees) @@ -229,13 +235,14 @@ def __getitem__(self, key): return GaiaObject(self.df.iloc[key], self, key) elif type(key) == slice: - ixdata = [i for i in range(min(key.stop,len(self._id)))] + ixdata = [i for i in range(min(key.stop, len(self._id)))] ixes = itertools.islice(ixdata, key.start, key.stop, key.step) return [self._object_class(self.df.iloc[i], self, i) for i in ixes] elif type(key) == tuple and isinstance(key[0], Iterable): # check it's a list of int-like? - return [self._object_class(self.df.iloc[i], self, i) for i in key[0]] + return [self._object_class(self.df.iloc[i], self, + i) for i in key[0]] def __len__(self): return len(self.df) diff --git a/skycatalogs/objects/galaxy_object.py b/skycatalogs/objects/galaxy_object.py index 49b8489a..65171fee 100644 --- a/skycatalogs/objects/galaxy_object.py +++ b/skycatalogs/objects/galaxy_object.py @@ -2,11 +2,14 @@ import galsim from .base_object import BaseObject +from skycatalogs.utils.translate_utils import form_object_string __all__ = ['GalaxyObject'] + class GalaxyObject(BaseObject): _type_name = 'galaxy' + def _get_sed(self, component=None, resolution=None): ''' Return sed and mag_norm for a galaxy component or for a star @@ -25,7 +28,7 @@ def _get_sed(self, component=None, resolution=None): raise ValueError(f'Cannot fetch SED for component type {component}') th_val = self.get_native_attribute(f'sed_val_{component}') - if th_val is None: # values for this component are not in the file + if th_val is None: # values for this component are not in the file raise ValueError(f'{component} not part of this catalog') # if values are all zeros or nearly no point in trying to convert @@ -46,11 +49,11 @@ def get_wl_params(self): """Return the weak lensing parameters, g1, g2, mu.""" gamma1 = self.get_native_attribute('shear_1') gamma2 = self.get_native_attribute('shear_2') - kappa = self.get_native_attribute('convergence') + kappa = self.get_native_attribute('convergence') # Compute reduced shears and magnification. g1 = gamma1/(1. - kappa) # real part of reduced shear g2 = gamma2/(1. - kappa) # imaginary part of reduced shear - mu = 1./((1. - kappa)**2 - (gamma1**2 + gamma2**2)) # magnification + mu = 1./((1. - kappa)**2 - (gamma1**2 + gamma2**2)) # magnification return g1, g2, mu def get_total_observer_sed(self, mjd=None): @@ -67,7 +70,6 @@ def get_total_observer_sed(self, mjd=None): sed *= mu return sed - def get_gsobject_components(self, gsparams=None, rng=None): if gsparams is not None: @@ -123,7 +125,7 @@ def get_observer_sed_component(self, component, mjd=None, resolution=None): return sed - def get_instcat_entry(self, band = 'r', component=None): + def get_instcat_entry(self, band='r', component=None): ''' Return the string corresponding to instance catalog line Parameters: diff --git a/skycatalogs/objects/snana_object.py b/skycatalogs/objects/snana_object.py index 6a458c25..ad643a43 100644 --- a/skycatalogs/objects/snana_object.py +++ b/skycatalogs/objects/snana_object.py @@ -2,11 +2,12 @@ import galsim import h5py import numpy as np -from .base_object import BaseObject,ObjectCollection +from .base_object import BaseObject, ObjectCollection from skycatalogs.utils.exceptions import SkyCatalogsRuntimeError __all__ = ['SnanaObject', 'SnanaCollection'] + class SnanaObject(BaseObject): _type_name = 'snana' @@ -47,7 +48,7 @@ def _flux_ratio(mag): # -0.9210340371976184 = -np.log(10)/2.5. return np.exp(-0.921034037196184 * mag) - flux = super().get_LSST_flux(band, sed, mjd) + flux = super().get_LSST_flux(band, sed=sed, cache=cache, mjd=mjd) if flux < 0: raise SkyCatalogsRuntimeError('Negative flux') @@ -58,7 +59,11 @@ def _flux_ratio(mag): mjd_ix_l, mjd_ix_u, mjd_fraction = self._find_mjd_interval(mjd) with h5py.File(self._belongs_to._SED_file, 'r') as f: - cors = f[self._id][f'magcor_{band}'] + try: + cors = f[self._id][f'magcor_{band}'] + except KeyError: + # nothing else to do + return flux # interpolate corrections if mjd_ix_l == mjd_ix_u: @@ -67,7 +72,7 @@ def _flux_ratio(mag): mag_cor = cors[mjd_ix_l] + mjd_fraction *\ (cors[mjd_ix_u] - cors[mjd_ix_l]) - #dbg = True + # dbg = True dbg = False # Do everything in flux units @@ -79,9 +84,6 @@ def _flux_ratio(mag): print(f' mag correction: {mag_cor}') print(f' multiplicative flux correction: {flux_cor}') - if cache: - att = f'lsst_flux_{band}' - setattr(self, att, corrected_flux) return corrected_flux def _find_mjd_interval(self, mjd=None): @@ -110,7 +112,7 @@ def _find_mjd_interval(self, mjd=None): # just return previously-computed values return self._mjd_ix_l, self._mjd_ix_u, self._mjd_fraction - if self._mjds is None: + if self._mjds is None: with h5py.File(self._belongs_to._SED_file, 'r') as f: self._mjds = np.array(f[self._id]['mjd']) mjds = self._mjds @@ -118,7 +120,7 @@ def _find_mjd_interval(self, mjd=None): mjd_fraction = None index = bisect.bisect(mjds, mjd) if index == 0: - mjd_ix_l = mjd_ix_u = 0 + mjd_ix_l = mjd_ix_u = 0 elif index == len(mjds): mjd_ix_l = mjd_ix_u = index - 1 else: @@ -133,13 +135,12 @@ def _find_mjd_interval(self, mjd=None): return mjd_ix_l, mjd_ix_u, mjd_fraction - def _linear_interp_SED(self, mjd=None): ''' Return galsim SED obtained by interpolating between SEDs for nearest mjds among the templates ''' - mjd_ix_l,mjd_ix_u,mjd_fraction = self._find_mjd_interval(mjd) + mjd_ix_l, mjd_ix_u, mjd_fraction = self._find_mjd_interval(mjd) with h5py.File(self._belongs_to._SED_file, 'r') as f: if self._mjds is None or self._lambda is None: diff --git a/skycatalogs/objects/sncosmo_object.py b/skycatalogs/objects/sncosmo_object.py index a1695b1d..78cbec93 100644 --- a/skycatalogs/objects/sncosmo_object.py +++ b/skycatalogs/objects/sncosmo_object.py @@ -5,8 +5,10 @@ __all__ = ['SncosmoObject'] + class SncosmoObject(BaseObject): _type_name = 'sncosmo' + def _get_sed(self, mjd=None): params = self.get_native_attribute('salt2_params') sn = SncosmoModel(params=params) @@ -27,7 +29,8 @@ def get_observer_sed_component(self, component, mjd=None): return sed def get_LSST_flux(self, band, sed=None, mjd=None): - if not band in LSST_BANDS: - return None + # if band not in BaseObject.LSST_BANDS: + # return None - return self.get_flux(lsst_bandpasses[band], sed=sed, mjd=mjd) + # return self.get_flux(lsst_bandpasses[band], sed=sed, mjd=mjd) + return super().get_LSST_flux(band, sed=sed, cache=False, mjd=mjd) diff --git a/skycatalogs/objects/star_object.py b/skycatalogs/objects/star_object.py index f7b17754..eb481263 100644 --- a/skycatalogs/objects/star_object.py +++ b/skycatalogs/objects/star_object.py @@ -5,16 +5,17 @@ __all__ = ['StarObject'] + class StarObject(BaseObject): + _type_name = 'star' + def _get_sed(self, mjd=None, redshift=0): ''' We'll need mjd when/if variable stars are supported. For now it's ignored. ''' mag_norm = self.get_native_attribute('magnorm') - rel_path = self.get_native_attribute('sed_filepath') - fpath = os.path.join(os.getenv('SIMS_SED_LIBRARY_DIR'), self.get_native_attribute('sed_filepath')) diff --git a/skycatalogs/readers/parquet_reader.py b/skycatalogs/readers/parquet_reader.py index 239206f6..f64af278 100644 --- a/skycatalogs/readers/parquet_reader.py +++ b/skycatalogs/readers/parquet_reader.py @@ -1,12 +1,11 @@ from collections import OrderedDict -import pyarrow as pa import pyarrow.parquet as pq import numpy as np import numpy.ma as ma -import warnings __all__ = ['ParquetReader'] + class ParquetReader: ''' Handle reads for a particular file diff --git a/skycatalogs/scripts/adjust_snana.py b/skycatalogs/scripts/adjust_snana.py index b5a41691..4022d9f9 100644 --- a/skycatalogs/scripts/adjust_snana.py +++ b/skycatalogs/scripts/adjust_snana.py @@ -10,11 +10,13 @@ Make a new skyCatalogs parquet file from an old, adding columns for MW_av, MW_rv ''') -parser.add_argument('indir', help='Directory containing input file(s). Required') +parser.add_argument('indir', + help='Directory containing input file(s). Required') parser.add_argument('outdir', help='Directory for output. Required') parser.add_argument('--pixels', type=int, nargs='*', default=[], help='healpix pixels for which new files will be created. Required') -parser.add_argument('--starts-with', help='That part of the filename preceding healpixel', +parser.add_argument('--starts-with', + help='That part of the filename preceding healpixel', default='snana_') args = parser.parse_args() diff --git a/skycatalogs/scripts/create_sc.py b/skycatalogs/scripts/create_sc.py index 65c381d3..9613d3f9 100644 --- a/skycatalogs/scripts/create_sc.py +++ b/skycatalogs/scripts/create_sc.py @@ -25,17 +25,26 @@ parser.add_argument('--pixels', type=int, nargs='*', default=[9556], help='healpix pixels for which catalogs will be created') parser.add_argument('--skycatalog-root', - help='Root directory for sky catalogs, typically site-dependent. If not specified, use value of environment variable SKYCATALOG_ROOT') -parser.add_argument('--catalog-dir', '--cat-dir', help='directory for output files relative to skycatalog_root', + help='''Root directory for sky catalogs, typically + site-dependent. If not specified, use value of + environment variable SKYCATALOG_ROOT''') +parser.add_argument('--catalog-dir', '--cat-dir', + help='output file directory relative to skycatalog_root', default='.') -parser.add_argument('--sed-subdir', help='subdirectory to prepend to paths of galaxy SEDs as written to the sky catalog', default='galaxyTopHatSED') +parser.add_argument('--sed-subdir', + help='''subdirectory to prepend to paths of galaxy SEDs + as written to the sky catalog''', + default='galaxyTopHatSED') parser.add_argument('--log-level', help='controls logging output', - default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR']) -parser.add_argument('--galaxy-magnitude-cut', '--gal-mag-cut', default=29.0, type=float, + default='INFO', choices=['DEBUG', 'INFO', 'WARNING', + 'ERROR']) +parser.add_argument('--galaxy-magnitude-cut', '--gal-mag-cut', + default=29.0, type=float, help='Exclude galaxies with r-magnitude above this value') parser.add_argument('--knots-magnitude-cut', default=27.0, type=float, help='Galaxies with i-magnitude above this cut get no knots') -parser.add_argument('--no-knots', action='store_true', help='If supplied omit knots component. Default is False') +parser.add_argument('--no-knots', action='store_true', + help='If supplied omit knots component. Default is False') parser.add_argument('--config-path', default=None, help=''' Output path for config file. If no value, @@ -102,7 +111,7 @@ if args.provenance: provenance = args.provenance else: - provenance=None + provenance = None creator = CatalogCreator(parts, area_partition=None, skycatalog_root=skycatalog_root, diff --git a/skycatalogs/skyCatalogs.py b/skycatalogs/skyCatalogs.py index 0a71360a..4c4db14a 100644 --- a/skycatalogs/skyCatalogs.py +++ b/skycatalogs/skyCatalogs.py @@ -6,11 +6,10 @@ import healpy import numpy as np import numpy.ma as ma -import pyarrow.parquet as pq from astropy import units as u from skycatalogs.objects.base_object import load_lsst_bandpasses from skycatalogs.utils.catalog_utils import CatalogContext -from skycatalogs.objects.base_object import ObjectList, ObjectCollection +from skycatalogs.objects.base_object import ObjectList from skycatalogs.objects.gaia_object import GaiaObject, GaiaCollection from skycatalogs.readers import ParquetReader from skycatalogs.utils.sed_tools import TophatSedFactory @@ -24,6 +23,7 @@ __all__ = ['SkyCatalog', 'open_catalog'] + # This function should maybe be moved to utils def _get_intersecting_hps(hp_ordering, nside, region): ''' @@ -43,7 +43,7 @@ def _get_intersecting_hps(hp_ordering, nside, region): region.dec_max, region.dec_max], lonlat=True) - pixels = healpy.query_polygon(nside, vec, inclusive=True, nest=False) + pixels = healpy.query_polygon(nside, vec, inclusive=True, nest=False) elif isinstance(region, Disk): # Convert inputs to the types query_disk expects center = healpy.pixelfunc.ang2vec(region.ra, region.dec, @@ -64,6 +64,7 @@ def _get_intersecting_hps(hp_ordering, nside, region): pixels.sort() return pixels + def _compress_via_mask(tbl, id_column, region, source_type={'galaxy'}, mjd=None): ''' @@ -110,7 +111,7 @@ def _compress_via_mask(tbl, id_column, region, source_type={'galaxy'}, bnd_box = Box(ra_min, ra_max, dec_min, dec_max) # Compute mask for that box mask = _compute_region_mask(bnd_box, tbl['ra'], tbl['dec']) - if all(mask): # even bounding box doesn't intersect table rows + if all(mask): # even bounding box doesn't intersect table rows if no_obj_type_return: return None, None, None, None else: @@ -156,14 +157,16 @@ def _compress_via_mask(tbl, id_column, region, source_type={'galaxy'}, time_mask = _compute_time_mask(mjd, tbl['start_mjd'], tbl['end_mjd']) ra_compress = ma.array(tbl['ra'], mask=time_mask).compressed() - dec_compress = ma.array(tbl['dec'], mask=time_mask).compressed() + dec_compress = ma.array(tbl['dec'], + mask=time_mask).compressed() id_compress = ma.array(tbl[id_column], mask=time_mask).compressed() return ra_compress, dec_compress, id_compress, time_mask else: - return tbl['ra'], tbl['dec'], tbl[id_column],None + return tbl['ra'], tbl['dec'], tbl[id_column], None else: - return tbl['ra'], tbl['dec'], tbl[id_column],tbl['object_type'],None + return tbl['ra'], tbl['dec'], tbl[id_column], tbl['object_type'], None + def _compute_region_mask(region, ra, dec): ''' @@ -192,20 +195,19 @@ def _compute_region_mask(region, ra, dec): lonlat=True) radius_rad = (region.radius_as * u.arcsec).to_value('radian') - # Rather than comparing arcs, it is equivalent to compare chords # (or square of chord length) - diff = p_vec - c_vec - obj_chord_sq = np.sum(np.square(p_vec - c_vec),axis=1) + obj_chord_sq = np.sum(np.square(p_vec - c_vec), axis=1) # This is to be compared to square of chord for angle a corresponding # to disk radius. That's 4(sin(a/2)^2) - rad_chord_sq = 4 * np.square(np.sin(0.5 * radius_rad) ) + rad_chord_sq = 4 * np.square(np.sin(0.5 * radius_rad)) mask = obj_chord_sq > rad_chord_sq if isinstance(region, PolygonalRegion): mask = region.get_containment_mask(ra, dec, included=False) return mask + def _compute_time_mask(current_mjd, start_mjd, end_mjd): ''' Starting with an existing mask of excluded objects, exclude additional @@ -225,6 +227,7 @@ def _compute_time_mask(current_mjd, start_mjd, end_mjd): return mask + class SkyCatalog(object): ''' A base class with derived classes for galaxies, static (w.r.t. coordinates) @@ -277,19 +280,18 @@ def __init__(self, config, mp=False, skycatalog_root=None, verbose=False, self.verbose = verbose self._validate_config() - # Outer dict: hpid for key. Value is another dict # with keys 'files', 'object_types', each with value another dict # for 'files', map filepath to handle (initially None) # for 'object_types', map object type to filepath self._hp_info = dict() - hps = self._find_all_hps() + _ = self._find_all_hps() # NOTE: the use of TophatSedFactory is appropriate *only* for an # input galaxy catalog with format like cosmoDC2, which includes # definitions of tophat SEDs. A different implementation will # be needed for newer galaxy catalogs - th_parameters = self._config.get_tophat_parameters(); + th_parameters = self._config.get_tophat_parameters() self._observed_sed_factory =\ TophatSedFactory(th_parameters, config['Cosmology']) @@ -363,20 +365,20 @@ def _find_all_hps(self): for f in files: for ot in o_types: # find all keys containing the string 'file_template' - template_keys = [k for k in o_types[ot] if 'file_template' in k] - for k in template_keys: + tmplt_keys = [k for k in o_types[ot] if 'file_template' in k] + for k in tmplt_keys: m = re.fullmatch(o_types[ot][k], f) if m: hp = int(m['healpix']) hp_set.add(hp) if hp not in self._hp_info: - self._hp_info[hp] = {'files' : {f : None}, - 'object_types' : {ot : [f]}} + self._hp_info[hp] = {'files': {f: None}, + 'object_types': {ot: [f]}} else: this_hp = self._hp_info[hp] # Value of 'object_types' is now a list - if f not in this_hp['files'] : + if f not in this_hp['files']: this_hp['files'][f] = None if ot in this_hp['object_types']: this_hp['object_types'][ot].append(f) @@ -503,7 +505,8 @@ def get_object_type_by_region(self, region, object_type, mjd=None): if partition['type'] == 'healpix': hps = self.get_hps_by_region(region, object_type) for hp in hps: - c = self.get_object_type_by_hp(hp, object_type, region, mjd) + c = self.get_object_type_by_hp(hp, object_type, + region, mjd) if len(c) > 0: out_list.append_object_list(c) return out_list @@ -539,10 +542,11 @@ def get_object_type_by_hp(self, hp, object_type, region=None, mjd=None): if 'file_template' in self._config['object_types'][object_type]: f_list = self._hp_info[hp]['object_types'][object_type] elif 'parent' in self._config['object_types'][object_type]: - f_list = self._hp_info[hp]['object_types'][self._config['object_types'][ot]['parent']] + # ##f_list = self._hp_info[hp]['object_types'][self._config['object_types'][ot]['parent']] + f_list = self._hp_info[hp]['object_types'][self._config['object_types'][object_type]['parent']] for f in f_list: - if self._hp_info[hp]['files'][f] is None: # no reader yet + if self._hp_info[hp]['files'][f] is None: # no reader yet full_path = os.path.join(self._cat_dir, f) the_reader = ParquetReader(full_path, mask=None) self._hp_info[hp]['files'][f] = the_reader @@ -578,7 +582,6 @@ def get_object_type_by_hp(self, hp, object_type, region=None, mjd=None): source_type={object_type}, mjd=mjd) if ra_c is not None: - ## new_collection = ObjectCollection(ra_c, dec_c, id_c, new_collection = coll_class(ra_c, dec_c, id_c, object_type, hp, self, region=region, @@ -609,10 +612,10 @@ def get_object_type_by_hp(self, hp, object_type, region=None, mjd=None): return object_list - # For generator version, do this a row group at a time # but if region cut leaves too small a list, read more rowgroups # to achieve a reasonable size list (or exhaust the file) + def get_object_iterator_by_hp(self, hp, obj_type_set=None, max_chunk=None, mjd=None): ''' @@ -648,11 +651,12 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): SkyCatalog ''' # Get LSST bandpasses in case we need to compute fluxes - band_passes = load_lsst_bandpasses() + _ = load_lsst_bandpasses() with open(config_file) as f: return SkyCatalog(yaml.safe_load(f), skycatalog_root=skycatalog_root, mp=mp, verbose=verbose) + if __name__ == '__main__': import time cfg_file_name = 'skyCatalog.yaml' @@ -673,18 +677,17 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): # 55.73604 < ra < 57.563452 # -37.19001 < dec < -35.702481 - cat = open_catalog(cfg_file, skycatalog_root=skycatalog_root) hps = cat._find_all_hps() print('Found {} healpix pixels '.format(len(hps))) - for h in hps: print(h) - + for h in hps: + print(h) ra_min_tract = 55.736 ra_max_tract = 57.564 dec_min_tract = -37.190 dec_max_tract = -35.702 - ##ra_min_small = 56.0 - ##ra_max_small = 56.2 + # ra_min_small = 56.0 + # ra_max_small = 56.2 ra_min_small = 55.9 ra_max_small = 56.1 dec_min_small = -36.2 @@ -715,16 +718,17 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): print('Invoke get_objects_by_region with box region, no gaia') t0 = time.time() object_list = cat.get_objects_by_region(rgn, - obj_type_set={'star','galaxy', + obj_type_set={'star', 'galaxy', 'sncosmo'}) t_done = time.time() print('Took ', t_done - t0) - ##### temporary obj_type_set={'galaxy', 'star'} ) + # #### temporary obj_type_set={'galaxy', 'star'} ) # obj_type_set=set(['galaxy']) ) # Try out get_objects_by_hp with no region - #colls = cat.get_objects_by_hp(9812, None, set(['galaxy']) ) + # colls = cat.get_objects_by_hp(9812, None, set(['galaxy']) ) - print('Number of collections returned for box: ', object_list.collection_count) + print('Number of collections returned for box: ', + object_list.collection_count) print('Object count for box: ', len(object_list)) print('Invoke get_objects_by_region with polygonal region') @@ -765,9 +769,8 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): object_list_diamond.collection_count) print('Object count for diamond: ', len(object_list_diamond)) - - #### TEMP FOR DEBUGGING - ### exit(0) + # ### TEMP FOR DEBUGGING + # ## exit(0) # For now SIMS_SED_LIBRARY_DIR is undefined at SLAC, making it impossible # to get SEDs for stars. So (crudely) determine whether or not @@ -783,14 +786,14 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): print("First object: ") print(c[0], '\nid=', c[0].id, ' ra=', c[0].ra, ' dec=', c[0].dec, ' belongs_index=', c[0]._belongs_index, - ' object_type: ', c[0].object_type ) + ' object_type: ', c[0].object_type) if (n_obj < 3): continue print("Slice [1:3]") slice13 = c[1:3] for o in slice13: - print('id=',o.id, ' ra=',o.ra, ' dec=',o.dec, ' belongs_index=', + print('id=', o.id, ' ra=', o.ra, ' dec=', o.dec, ' belongs_index=', o._belongs_index, ' object_type: ', o.object_type) print(o.object_type) if o.object_type == 'star': @@ -807,7 +810,7 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): print(cmp) if cmp in o.subcomponents: # broken for galaxies currently - ###print(o.get_instcat_entry(component=cmp)) + # print(o.get_instcat_entry(component=cmp)) sed, _ = o._get_sed(cmp) if sed: print('Length of sed table: ', len(sed.wave_list)) @@ -821,12 +824,15 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): print('Simple sed values:') print([sed(w) for w in sed.wave_list]) if write_sed: - o.write_sed('simple_sed.txt', component=cmp) + o.write_sed('simple_sed.txt', + component=cmp) sed_fine, _ = o._get_sed(component=cmp, - resolution=1.0) + resolution=1.0) print('Bin width = 1 nm') - print('Initial wl values', sed_fine.wave_list[:20]) - print('Start at bin 100', sed_fine.wave_list[100:120]) + print('Initial wl values', + sed_fine.wave_list[:20]) + print('Start at bin 100', + sed_fine.wave_list[100:120]) print('Initial values') print([sed_fine(w) for w in sed_fine.wave_list[:20]]) print('Start at bin 100') @@ -851,7 +857,7 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): f = o.get_LSST_flux('i') print(f'Flux for i bandpass: {f}') fluxes = o.get_LSST_fluxes() - for k,v in fluxes.items(): + for k, v in fluxes.items(): print(f'Bandpass {k} has flux {v}') if n_obj > 200: @@ -863,8 +869,8 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): slice_late = c[163994:163997] print('\nobjects indexed 163994 through 163996') for o in slice_late: - print('id=',o.id, ' ra=',o.ra, ' dec=',o.dec, ' belongs_index=', - o._belongs_index) + print('id=', o.id, ' ra=', o.ra, ' dec=', o.dec, + ' belongs_index=', o._belongs_index) print('Total object count: ', len(object_list)) @@ -879,7 +885,6 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): redshift0 = object_list[0].get_native_attribute('redshift') print('First redshift: ', redshift0) - sum = 0 for obj in object_list: sum = sum + 1 @@ -909,15 +914,15 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): for o in segment: print(f'object {o.id} of type {o.object_type} belongs to collection {o._belongs_to}') - #ixes = ([3,5,8],) - ixes = (np.array([3,5,8, 300, 303]),) + # ixes = ([3,5,8],) + ixes = (np.array([3, 5, 8, 300, 303]),) print(f'\nObjects with indexes {ixes[0]}') for o in object_list[ixes]: print(o.id) - print(f'\nObjects in slice [3:9]') + print('\nObjects in slice [3:9]') for o in object_list[3:9]: print(o.id) - print(f'\nObjects in slice [300:304]') + print('\nObjects in slice [300:304]') for o in object_list[300:304]: print(o.id) diff --git a/skycatalogs/utils/SED_parquet.py b/skycatalogs/utils/SED_parquet.py index 88e0ddf6..20b979b6 100644 --- a/skycatalogs/utils/SED_parquet.py +++ b/skycatalogs/utils/SED_parquet.py @@ -4,17 +4,18 @@ import pyarrow.parquet as pq import pandas as pd + def make_parquet(input_path): ''' Given a text file in table format where columns are wavelength (nanometers) and flux, convert to parquet ''' - wv,flux = np.genfromtxt(input_path, unpack=True) - #wv32 = np.array(wv, np.float32) + wv, flux = np.genfromtxt(input_path, unpack=True) + # wv32 = np.array(wv, np.float32) df = pd.DataFrame({'wavelength': wv, 'flux': flux}) - #df = pd.DataFrame({'wavelength': wv32, 'flux': flux}) + # df = pd.DataFrame({'wavelength': wv32, 'flux': flux}) # This method produces a file somewhat larger than if we left wv alone table = pa.Table.from_pandas(df) @@ -25,6 +26,7 @@ def make_parquet(input_path): out_path = input_path + '.parquet' pq.write_table(table, out_path) + if __name__ == '__main__': if len(sys.argv) == 1: print('Requires filepath argument') @@ -32,7 +34,7 @@ def make_parquet(input_path): fname = sys.argv[1] - print('Called with filepath argument ',fname) + print('Called with filepath argument ', fname) mode = 'parquet' diff --git a/skycatalogs/utils/add_extinction.py b/skycatalogs/utils/add_extinction.py index 3ab7d91a..994120f7 100644 --- a/skycatalogs/utils/add_extinction.py +++ b/skycatalogs/utils/add_extinction.py @@ -1,10 +1,11 @@ import os -import numpy as np import pyarrow.parquet as pq import pyarrow as pa import pandas as pd -from skycatalogs.utils.creator_utils import make_MW_extinction_av, make_MW_extinction_rv +from skycatalogs.utils.creator_utils import make_MW_extinction_av +from skycatalogs.utils.creator_utils import make_MW_extinction_rv + class AddExtinction(): def __init__(self, in_dir, out_dir, starts_with): @@ -23,14 +24,13 @@ def __init__(self, in_dir, out_dir, starts_with): self._starts_with = starts_with def write(self, pixel): - infile = pq.ParquetFile(os.path.join(self._in_dir, - f'{self._starts_with}{str(pixel)}.parquet')) + fname = f'{self._starts_with}{str(pixel)}.parquet' + infile = pq.ParquetFile(os.path.join(self._in_dir, fname)) arrow_schema = (infile.schema).to_arrow_schema() out_schema = arrow_schema.append(pa.field('MW_av', pa.float32())) out_schema = out_schema.append(pa.field('MW_rv', pa.float32())) - writer = pq.ParquetWriter(os.path.join(self._out_dir, - f'{self._starts_with}{str(pixel)}.parquet'), + writer = pq.ParquetWriter(os.path.join(self._out_dir, fname), out_schema) n_row_group = infile.metadata.num_row_groups @@ -39,8 +39,10 @@ def write(self, pixel): tbl = infile.read_row_group(g) for c in arrow_schema.names: out_dict[c] = [i.as_py() for i in tbl[c]] - out_dict['MW_av'] = make_MW_extinction_av(out_dict['ra'], out_dict['dec']) - out_dict['MW_rv'] = make_MW_extinction_rv(out_dict['ra'], out_dict['dec']) + out_dict['MW_av'] = make_MW_extinction_av(out_dict['ra'], + out_dict['dec']) + out_dict['MW_rv'] = make_MW_extinction_rv(out_dict['ra'], + out_dict['dec']) out_df = pd.DataFrame.from_dict(out_dict) out_table = pa.Table.from_pandas(out_df, schema=out_schema) writer.write_table(out_table) diff --git a/skycatalogs/utils/catalog_utils.py b/skycatalogs/utils/catalog_utils.py index 9884c2be..2a11237a 100644 --- a/skycatalogs/utils/catalog_utils.py +++ b/skycatalogs/utils/catalog_utils.py @@ -1,5 +1,7 @@ __all__ = ['CatalogContext'] from skycatalogs.objects.base_object import ObjectCollection + + class CatalogContext: def __init__(self, the_sky_cat): global sky_cat @@ -18,9 +20,9 @@ def __init__(self, the_sky_cat): def register_source_type(self, name, object_class, collection_class=ObjectCollection, custom_load=False): - self._source_type_dict[name] = {'object_class' : object_class, - 'collection_class' : collection_class, - 'custom_load' : custom_load} + self._source_type_dict[name] = {'object_class': object_class, + 'collection_class': collection_class, + 'custom_load': custom_load} def lookup_source_type(self, name): if name in self._source_type_dict: @@ -36,6 +38,7 @@ def lookup_collection_type(self, name): return None else: return None + def use_custom_load(self, name): if name in self._source_type_dict: return self._source_type_dict[name]['custom_load'] diff --git a/skycatalogs/utils/config_utils.py b/skycatalogs/utils/config_utils.py index 17c0fb3a..00450b89 100644 --- a/skycatalogs/utils/config_utils.py +++ b/skycatalogs/utils/config_utils.py @@ -2,18 +2,19 @@ import yaml import git import logging -from jsonschema import validate - -from .exceptions import NoSchemaVersionError, ConfigDuplicateKeyError +from .exceptions import ConfigDuplicateKeyError +# import jsonschema from collections import namedtuple __all__ = ['Config', 'open_config_file', 'Tophat', 'create_config', 'assemble_SED_models', 'assemble_MW_extinction', - 'assemble_cosmology', 'assemble_object_types', 'assemble_provenance', - 'assemble_variability_models', 'write_yaml', 'CURRENT_SCHEMA_VERSION'] + 'assemble_cosmology', 'assemble_object_types', + 'assemble_provenance', 'assemble_variability_models', 'write_yaml', + 'CURRENT_SCHEMA_VERSION'] + +CURRENT_SCHEMA_VERSION = '1.2.0' -CURRENT_SCHEMA_VERSION='1.2.0' def open_config_file(config_file): ''' @@ -22,6 +23,7 @@ def open_config_file(config_file): with open(config_file) as f: return Config(yaml.safe_load(f)) + Tophat = namedtuple('Tophat', ['start', 'width']) @@ -46,16 +48,19 @@ class DelegatorBase: def _delegate(self): pub = [o for o in dir(self.default) if not o.startswith('_')] return pub + def __getattr__(self, k): if k in self._delegate: return getattr(self.default, k) raise AttributeError(k) + def __dir__(self): return _custom_dir(self, self._delegate) def __init__(self): pass + class Config(DelegatorBase): ''' A wrapper around the dict which is the contents of a Sky Catalog @@ -119,7 +124,7 @@ def get_tophat_parameters(self): return None raw_bins = self._cfg['SED_models']['tophat']['bins'] - return [ Tophat(b[0], b[1]) for b in raw_bins] + return [Tophat(b[0], b[1]) for b in raw_bins] def get_config_value(self, key_path, silent=False): ''' @@ -138,7 +143,7 @@ def get_config_value(self, key_path, silent=False): path_items = key_path.split('/') d = self._cfg for i in path_items[:-1]: - if not i in d.keys(): + if i not in d.keys(): if silent: return None raise ValueError(f'Item {i} not found') @@ -147,46 +152,6 @@ def get_config_value(self, key_path, silent=False): raise ValueError(f'intermediate {d} is not a dict') return d[path_items[-1]] - def validate(self, schema=None): - ''' - Parameters - ---------- - schema Identifies schema to validate against. - If string, open file with this file path - If dict, use "as is" - If None, attempt to deduce schema version (and hence - file path) from schema_version keyword - Returns: None - Raise exception if validation fails for any reason: - skycatalogs.exception.NoSchema if schema specification - can't be found - OSError if schema file can't be read - yaml.YAMLerror if schema can't be loaded - jsonschema.exceptions.SchemaError if schema is invalid - jsonschema.exceptions.ValidationError otherwise - ''' - fpath = None - if schema is None: - if 'schema_version' not in self._cfg: - raise NoSchemaVersionError - fpath = _find_schema_path(self._cfg["schema_version"]) - if not fpath: - raise NoSchemaVersionError('Schema file not found') - elif isinstance(schema, string): - fpath = schema - if fpath: - try: - f = open(fpath) - sch = yaml.safe_load(f) - except OSError as e: - raise NoSchemaVersionError('Schema file not found or unreadable') - except yaml.YAMLError as ye: - raise ye - if isinstance(schema, dict): - sch = schema - - jsonschema.validate(self._cfg, schema.dict) - def add_key(self, k, v): ''' Parameters @@ -214,7 +179,7 @@ def write_config(self, dirpath, filename=None, overwrite=False): ------ Full path of output config ''' - ###self.validate() skip for now + # ##self.validate() skip for now if not filename: filename = self._cfg['catalog_name'] + '.yaml' @@ -224,49 +189,43 @@ def write_config(self, dirpath, filename=None, overwrite=False): def write_yaml(input_dict, outpath, overwrite=False, logname=None): - if not overwrite: - try: - with open(outpath, mode='x') as f: - yaml.dump(input_dict, f) - except FileExistsError: - if logname: - logger = logging.getLogger(logname) - logger.warning('Config.write_yaml: Will not overwrite pre-existing config file ' + outpath) - else: - print('Config.write_yaml: Will not overwrite pre-existing config file ' + outpath) - return - - else: - with open(outpath, mode='w') as f: + if not overwrite: + try: + with open(outpath, mode='x') as f: yaml.dump(input_dict, f) + except FileExistsError: + txt = 'Config.write_yaml: Will not overwrite pre-existing config ' + if logname: + logger = logging.getLogger(logname) + logger.warning(txt + outpath) + else: + print(txt + outpath) + return + else: + with open(outpath, mode='w') as f: + yaml.dump(input_dict, f) - return outpath - + return outpath -def _find_schema_path(schema_spec): - ''' - Given a schema version specification, return the file path - where the file describing it belongs - ''' - fname = f'skycatalogs_schema_{self._cfg["schema_spec"]}' - here = os.path.dirname(__file__) - return os.path.join(here, '../../cfg', fname) def create_config(catalog_name, logname=None): - return Config({'catalog_name' : catalog_name}, logname) -# 'schema_version' : schema_version, -# 'code_version' : desc.skycatalogs.__version__}, logname) + return Config({'catalog_name': catalog_name}, logname) +# 'schema_version': schema_version, +# 'code_version': desc.skycatalogs.__version__}, logname) + def assemble_cosmology(cosmology): - d = {k : cosmology.__getattribute__(k) for k in ('Om0', 'Ob0', 'sigma8', - 'n_s')} + d = {k: cosmology.__getattribute__(k) for k in ('Om0', 'Ob0', 'sigma8', + 'n_s')} d['H0'] = float(cosmology.H0.value) return d + def assemble_MW_extinction(): - av = {'mode' : 'data'} - rv = {'mode' : 'constant', 'value' : 3.1} - return {'r_v' : rv, 'a_v' : av} + av = {'mode': 'data'} + rv = {'mode': 'constant', 'value': 3.1} + return {'r_v': rv, 'a_v': av} + def assemble_object_types(pkg_root, galaxy_nside=32): ''' @@ -279,18 +238,20 @@ def assemble_object_types(pkg_root, galaxy_nside=32): d['object_types']['galaxy']['area_partition']['nside'] = galaxy_nside return d['object_types'] + def assemble_SED_models(bins): - tophat_d = { 'units' : 'angstrom', 'bin_parameters' : ['start', 'width']} + tophat_d = {'units': 'angstrom', 'bin_parameters': ['start', 'width']} tophat_d['bins'] = bins - file_nm_d = {'units' : 'nm'} - return {'tophat' : tophat_d, 'file_nm' : file_nm_d} + file_nm_d = {'units': 'nm'} + return {'tophat': tophat_d, 'file_nm': file_nm_d} + def assemble_provenance(pkg_root, inputs={}, schema_version=None): if not schema_version: schema_version = CURRENT_SCHEMA_VERSION import skycatalogs - version_d = {'schema_version' : schema_version} + version_d = {'schema_version': schema_version} if '__version__' in dir(skycatalogs): code_version = skycatalogs.__version__ else: @@ -301,7 +262,6 @@ def assemble_provenance(pkg_root, inputs={}, schema_version=None): has_uncommited = repo.is_dirty() has_untracked = (len(repo.untracked_files) > 0) - git_d = {} git_d['git_hash'] = repo.commit().hexsha git_d['git_branch'] = repo.active_branch.name @@ -315,16 +275,18 @@ def assemble_provenance(pkg_root, inputs={}, schema_version=None): git_d['git_status'] = status if inputs: - return {'versioning' : version_d,'skyCatalogs_repo' : git_d, - 'inputs' : inputs} + return {'versioning': version_d, 'skyCatalogs_repo': git_d, + 'inputs': inputs} else: - return{'versioning' : version_d, 'skyCatalogs_repo' : git_d} + return{'versioning': version_d, 'skyCatalogs_repo': git_d} + # In config just keep track of models by object type. Information # about the parameters they require is internal to the code. _AGN_MODELS = ['agn_random_walk'] _SNCOSMO_MODELS = ['sn_salt2_extended'] + def assemble_variability_models(object_types): ''' Add information about all known variability models for supplied object diff --git a/skycatalogs/utils/creator_utils.py b/skycatalogs/utils/creator_utils.py index 45167234..58799501 100644 --- a/skycatalogs/utils/creator_utils.py +++ b/skycatalogs/utils/creator_utils.py @@ -3,6 +3,8 @@ _Av_adjustment = 2.742 _MW_rv_constant = 3.1 + + def make_MW_extinction_av(ra, dec): ''' Given arrays of ra & dec, create a MW Av column corresponding to V-band @@ -10,7 +12,8 @@ def make_MW_extinction_av(ra, dec): See "Plotting Dust Maps" example in https://dustmaps.readthedocs.io/en/latest/examples.html - The coefficient _Av_adjustment comes Table 6 in Schlafly & Finkbeiner (2011) + The coefficient _Av_adjustment comes from Table 6 in + Schlafly & Finkbeiner (2011) See http://iopscience.iop.org/0004-637X/737/2/103/article#apj398709t6 Parameters @@ -21,9 +24,10 @@ def make_MW_extinction_av(ra, dec): ''' sfd = SFDQuery() - ebv_raw = np.array(sfd.query_equ(ra, dec)) + ebv_raw = np.array(sfd.query_equ(np.array(ra), np.array(dec))) return _Av_adjustment * ebv_raw + def make_MW_extinction_rv(ra, dec): return np.full_like(np.array(ra), _MW_rv_constant) diff --git a/skycatalogs/utils/exceptions.py b/skycatalogs/utils/exceptions.py index ac22aca3..b49a74aa 100644 --- a/skycatalogs/utils/exceptions.py +++ b/skycatalogs/utils/exceptions.py @@ -1,9 +1,11 @@ __all__ = ['SkyCatalogsException', 'NoSchemaVersionError', 'ConfigDuplicateKeyError'] + class SkyCatalogsException(Exception): pass + class NoSchemaVersionError(SkyCatalogsException): def __init__(self, msg): @@ -12,11 +14,13 @@ def __init__(self, msg): self.msg = msg super().__init__(self.msg) + class ConfigDuplicateKeyError(SkyCatalogsException): def __init__(self, key): self.msg = f'Cannot add duplicate key {key} to config' super().__init__(self.msg) + class SkyCatalogsRuntimeError(SkyCatalogsException): def __init__(self, msg): if not msg: diff --git a/skycatalogs/utils/make_fake.py b/skycatalogs/utils/make_fake.py deleted file mode 100644 index 47a52041..00000000 --- a/skycatalogs/utils/make_fake.py +++ /dev/null @@ -1,29 +0,0 @@ -import io -import argparse - -''' -Read in a bit of old instance catalogs (single healpixel; bulge, disk and -knots). For each of the three: -1. Transpose; store in dataframe. Columns are -prefix (string 'object') ignore -id (identifies galaxy) use -ra,dec use -magnitude (for band observed in this visit) ?? -sedfilepath probably ignore -redshift use -gamma1,gamma2,kappa use -raOffset,decOffset ignore -spatialmodel goes in config -majorAxis,minorAxis,positionAngle,sindex use (spatial_params) -internalExtinctionModel goes in config -internalAv,internalRv use -galacticExtinctionModel goes in config -galacticAv,galacticRv use - -2. Add column for component type -3. Add a column for tophat values (make them up?) -4. Add empty column for rel. SED filepath(s) -5. Sort by galaxy id -6. write out as parquet - -''' diff --git a/skycatalogs/utils/parquet_schema_utils.py b/skycatalogs/utils/parquet_schema_utils.py index 4b11e2f8..5d671cfc 100644 --- a/skycatalogs/utils/parquet_schema_utils.py +++ b/skycatalogs/utils/parquet_schema_utils.py @@ -4,14 +4,15 @@ __all__ = ['make_galaxy_schema', 'make_galaxy_flux_schema', 'make_pointsource_schema', 'make_star_flux_schema'] + # This schema is not the same as the one taken from the data, # probably because of the indexing in the schema derived from a pandas df. def make_galaxy_schema(logname, sed_subdir=False, knots=True): fields = [pa.field('galaxy_id', pa.int64()), - pa.field('ra', pa.float64() , True), -## metadata={"units" : "radians"}), - pa.field('dec', pa.float64() , True), -## metadata={"units" : "radians"}), + pa.field('ra', pa.float64(), True), + # metadata={"units" : "radians"}), + pa.field('dec', pa.float64(), True), + # metadata={"units" : "radians"}), pa.field('redshift', pa.float64(), True), pa.field('redshift_hubble', pa.float64(), True), pa.field('peculiar_velocity', pa.float64(), True), @@ -24,7 +25,7 @@ def make_galaxy_schema(logname, sed_subdir=False, knots=True): pa.field('size_disk_true', pa.float32(), True), pa.field('size_minor_disk_true', pa.float32(), True), pa.field('sersic_disk', pa.float32(), True), -## pa.field('position_angle_unlensed', pa.float64(), True), + # Depending on value of --dc2-like option, value for # ellipticity_2_true column will differ pa.field('ellipticity_1_disk_true', pa.float64(), True), @@ -44,9 +45,9 @@ def make_galaxy_schema(logname, sed_subdir=False, knots=True): logger.debug("knots requested") fields.append(pa.field('sed_val_knots', pa.list_(pa.float64()), True)) - ### For sizes API can alias to disk sizes - ### position angle, shears and convergence are all - ### galaxy-wide quantities. + # For sizes API can alias to disk sizes + # position angle, shears and convergence are all + # galaxy-wide quantities. fields.append(pa.field('n_knots', pa.float32(), True)) fields.append(pa.field('knots_magnorm', pa.float64(), True)) @@ -60,6 +61,7 @@ def make_galaxy_schema(logname, sed_subdir=False, knots=True): logger.debug(debug_out) return pa.schema(fields) + def make_galaxy_flux_schema(logname): ''' Will make a separate parquet file with lsst flux for each band @@ -69,14 +71,15 @@ def make_galaxy_flux_schema(logname): logger.debug('Creating galaxy flux schema') fields = [pa.field('galaxy_id', pa.int64()), # should flux fields be named e.g. lsst_cmodel_flux_u? - pa.field('lsst_flux_u', pa.float32() , True), - pa.field('lsst_flux_g', pa.float32() , True), - pa.field('lsst_flux_r', pa.float32() , True), - pa.field('lsst_flux_i', pa.float32() , True), - pa.field('lsst_flux_z', pa.float32() , True), - pa.field('lsst_flux_y', pa.float32() , True)] + pa.field('lsst_flux_u', pa.float32(), True), + pa.field('lsst_flux_g', pa.float32(), True), + pa.field('lsst_flux_r', pa.float32(), True), + pa.field('lsst_flux_i', pa.float32(), True), + pa.field('lsst_flux_z', pa.float32(), True), + pa.field('lsst_flux_y', pa.float32(), True)] return pa.schema(fields) + def make_star_flux_schema(logname): ''' Will make a separate parquet file with lsst flux for each band @@ -85,14 +88,15 @@ def make_star_flux_schema(logname): logger = logging.getLogger(logname) logger.debug('Creating star flux schema') fields = [pa.field('id', pa.string()), - pa.field('lsst_flux_u', pa.float32() , True), - pa.field('lsst_flux_g', pa.float32() , True), - pa.field('lsst_flux_r', pa.float32() , True), - pa.field('lsst_flux_i', pa.float32() , True), - pa.field('lsst_flux_z', pa.float32() , True), - pa.field('lsst_flux_y', pa.float32() , True)] + pa.field('lsst_flux_u', pa.float32(), True), + pa.field('lsst_flux_g', pa.float32(), True), + pa.field('lsst_flux_r', pa.float32(), True), + pa.field('lsst_flux_i', pa.float32(), True), + pa.field('lsst_flux_z', pa.float32(), True), + pa.field('lsst_flux_y', pa.float32(), True)] return pa.schema(fields) + def make_pointsource_schema(): ''' Ultimately should handle stars both static and variable, SN, and AGN @@ -121,9 +125,10 @@ def make_pointsource_schema(): pa.field('parallax', pa.float64(), True), pa.field('variability_model', pa.string(), True), pa.field('salt2_params', pa.struct(salt2_fields), True) - ] + ] return pa.schema(fields) + def make_pointsource_flux_schema(logname): ''' Will make a separate parquet file with lsst flux for each band @@ -134,13 +139,13 @@ def make_pointsource_flux_schema(logname): logger = logging.getLogger(logname) logger.debug('Creating pointsource flux schema') fields = [pa.field('id', pa.string()), - pa.field('lsst_flux_u', pa.float32() , True), - pa.field('lsst_flux_g', pa.float32() , True), - pa.field('lsst_flux_r', pa.float32() , True), - pa.field('lsst_flux_i', pa.float32() , True), - pa.field('lsst_flux_z', pa.float32() , True), - pa.field('lsst_flux_y', pa.float32() , True), - pa.field('mjd', pa.float64() , True)] + pa.field('lsst_flux_u', pa.float32(), True), + pa.field('lsst_flux_g', pa.float32(), True), + pa.field('lsst_flux_r', pa.float32(), True), + pa.field('lsst_flux_i', pa.float32(), True), + pa.field('lsst_flux_z', pa.float32(), True), + pa.field('lsst_flux_y', pa.float32(), True), + pa.field('mjd', pa.float64(), True)] return pa.schema(fields) return pa.schema(fields) diff --git a/skycatalogs/utils/sed_tools.py b/skycatalogs/utils/sed_tools.py index d03033ef..4652d610 100644 --- a/skycatalogs/utils/sed_tools.py +++ b/skycatalogs/utils/sed_tools.py @@ -3,15 +3,14 @@ from astropy import units as u from astropy.cosmology import FlatLambdaCDM import astropy.constants -import h5py -import pandas as pd import numpy as np -import numpy.ma as ma from dust_extinction.parameter_averages import F19 import galsim __all__ = ['TophatSedFactory', 'MilkyWayExtinction', 'get_star_sed_path'] + + class TophatSedFactory: ''' Used for modeling cosmoDC2 galaxy SEDs, which are represented with @@ -97,7 +96,7 @@ def create(self, Lnu, redshift_hubble, redshift, resolution=None): ''' # Compute Llambda in units of W/nm Llambda = (Lnu*self._to_W_per_Hz*(self.nu[:-1] - self.nu[1:]) - /(self.wl[1:] - self.wl[:-1])) + / (self.wl[1:] - self.wl[:-1])) # Fill the arrays for the galsim.LookupTable. Prepend # zero-valued bins down to mix extinction wl to handle redshifts z > 2. @@ -119,7 +118,8 @@ def create(self, Lnu, redshift_hubble, redshift, resolution=None): if resolution: wl_min = min(self.wl_deltas) wl_max = max(self.wl_deltas) - wl_res = np.linspace(wl_min, wl_max, int((wl_max - wl_min)/resolution)) + wl_res = np.linspace(wl_min, wl_max, + int((wl_max - wl_min)/resolution)) flambda_res = [lut(wl) for wl in wl_res] lut = galsim.LookupTable(wl_res, flambda_res, interpolant='linear') @@ -136,6 +136,7 @@ def magnorm(self, tophat_values, z_H): with np.errstate(divide='ignore', invalid='ignore'): return -2.5*np.log10(Fnu/one_Jy) + 8.90 + class MilkyWayExtinction: ''' Applies extinction to a SED @@ -171,9 +172,10 @@ def extinguish(self, sed, mwAv): return sed -_standard_dict = {'lte' : 'starSED/phoSimMLT', - 'bergeron' : 'starSED/wDs', - 'km|kp' : 'starSED/kurucz'} +_standard_dict = {'lte': 'starSED/phoSimMLT', + 'bergeron': 'starSED/wDs', + 'km|kp': 'starSED/kurucz'} + def get_star_sed_path(filename, name_to_folder=_standard_dict): ''' @@ -182,7 +184,8 @@ def get_star_sed_path(filename, name_to_folder=_standard_dict): Parameters ---------- - filename list of strings. Usually full filename but may be missing final ".gz" + filename list of strings. Usually full filename but may be + missing final ".gz" name_to_folder dict mapping regular expression (to be matched with filename) to relative path for containing directory @@ -191,13 +194,13 @@ def get_star_sed_path(filename, name_to_folder=_standard_dict): Full path for file, relative to SIMS_SED_LIBRARY_DIR ''' - compiled = { re.compile(k) : v for (k, v) in name_to_folder.items()} + compiled = {re.compile(k): v for (k, v) in name_to_folder.items()} path_list = [] for f in filename: m = None matched = False - for k,v in compiled.items(): + for k, v in compiled.items(): f = f.strip() m = k.match(f) if m: diff --git a/skycatalogs/utils/shapes.py b/skycatalogs/utils/shapes.py index a35b37bd..7a739997 100644 --- a/skycatalogs/utils/shapes.py +++ b/skycatalogs/utils/shapes.py @@ -1,7 +1,7 @@ from collections import namedtuple import numpy as np from astropy import units as u - +from lsst.sphgeom import ConvexPolygon, UnitVector3d, LonLat __all__ = ['Box', 'Disk', 'PolygonalRegion'] @@ -10,7 +10,7 @@ # radius is measured in arcseconds Disk = namedtuple('Disk', ['ra', 'dec', 'radius_as']) -from lsst.sphgeom import ConvexPolygon, UnitVector3d, LonLat + class PolygonalRegion: def __init__(self, vertices_radec=None, convex_polygon=None): diff --git a/skycatalogs/utils/translate_utils.py b/skycatalogs/utils/translate_utils.py index 2ac88f53..dd3c47f9 100644 --- a/skycatalogs/utils/translate_utils.py +++ b/skycatalogs/utils/translate_utils.py @@ -1,4 +1,4 @@ -from collections import namedtuple, OrderedDict +from collections import namedtuple from enum import Enum import numpy as np @@ -7,27 +7,36 @@ 'form_star_instance_columns', 'form_cmp_instance_columns'] -STAR_FMT = '{:s} {:s} {:.14f} {:.14f} {:.8f} {:s} {:d} {:d} {:d} {:d} {:d} {:d} {:s} {:s} {:s} {:.8f} {:f}' +STAR_FMT = '''{:s} {:s} {:.14f} {:.14f} {:.8f} {:s} {:d} {:d} {:d} {:d} {:d} +{:d} {:s} {:s} {:s} {:.8f} {:f}''' + +CMP_FMT = '''{:s} {:s} {:.14f} {:.14f} {:.8f} {:s} {:.9g} {:.9g} {:.9g} {:.9g} + {:d} {:d} {:s} {:.9g} {:.9g} {:f} {:.0f} {:s} {:s} {:.8f} {:f}''' -CMP_FMT = '{:s} {:s} {:.14f} {:.14f} {:.8f} {:s} {:.9g} {:.9g} {:.9g} {:.9g} {:d} {:d} {:s} {:.9g} {:.9g} {:f} {:.0f} {:s} {:s} {:.8f} {:f}' def form_star_instance_columns(band): - star_instance = [column_finder('prefix', SourceType.FIXED, ('object', np.dtype('U6'))), - #column_finder('uniqueId', SourceType.DATA, 'id'), + star_instance = [column_finder('prefix', SourceType.FIXED, + ('object', np.dtype('U6'))), + # column_finder('uniqueId', SourceType.DATA, 'id'), column_finder('uniquePsId', SourceType.COMPUTE, ['id']), column_finder('raPhoSim', SourceType.DATA, 'ra'), column_finder('decPhoSim', SourceType.DATA, 'dec'), - column_finder('maskedMagNorm', SourceType.DATA, 'magnorm'), - column_finder('sedFilepath',SourceType.DATA, 'sed_filepath'), + column_finder('maskedMagNorm', SourceType.DATA, + 'magnorm'), + column_finder('sedFilepath', SourceType.DATA, + 'sed_filepath'), column_finder('redshift', SourceType.FIXED, (0, int)), column_finder('gamma1', SourceType.FIXED, (0, int)), column_finder('gamma2', SourceType.FIXED, (0, int)), column_finder('kappa', SourceType.FIXED, (0, int)), column_finder('raOffset', SourceType.FIXED, (0, int)), column_finder('decOffset', SourceType.FIXED, (0, int)), - column_finder('spatialmodel', SourceType.FIXED, ('point', np.dtype('U5'))), - column_finder('internalExtinctionModel', SourceType.FIXED, ('none', np.dtype('U4'))), - column_finder('galacticExtinctionModel', SourceType.CONFIG, + column_finder('spatialmodel', SourceType.FIXED, + ('point', np.dtype('U5'))), + column_finder('internalExtinctionModel', SourceType.FIXED, + ('none', np.dtype('U4'))), + column_finder('galacticExtinctionModel', + SourceType.CONFIG, ('object_types/star/MW_extinction', str)), column_finder('galactivAv', SourceType.DATA, 'MW_av'), @@ -35,13 +44,18 @@ def form_star_instance_columns(band): ('MW_extinction_values/r_v/value', float))] return star_instance + def _form_knots_instance_columns(cmp, band): - cmp_instance = [column_finder('prefix', SourceType.FIXED, ('object', np.dtype('U6'))), - column_finder('uniqueId', SourceType.COMPUTE, ['galaxy_id', f'{cmp}']), + cmp_instance = [column_finder('prefix', SourceType.FIXED, ('object', + np.dtype('U6'))), + column_finder('uniqueId', SourceType.COMPUTE, + ['galaxy_id', f'{cmp}']), column_finder('raPhoSim', SourceType.DATA, 'ra'), column_finder('decPhoSim', SourceType.DATA, 'dec'), - column_finder('phoSimMagNorm', SourceType.DATA, 'knots_magnorm'), - column_finder('sedFilepath',SourceType.COMPUTE, [f'sed_val_{cmp}','redshift_hubble']), + column_finder('phoSimMagNorm', SourceType.DATA, + 'knots_magnorm'), + column_finder('sedFilepath', SourceType.COMPUTE, + [f'sed_val_{cmp}', 'redshift_hubble']), column_finder('redshift', SourceType.DATA, 'redshift'), column_finder('gamma1', SourceType.DATA, 'shear_1'), column_finder('gamma2', SourceType.DATA, 'shear_2'), @@ -49,31 +63,41 @@ def _form_knots_instance_columns(cmp, band): column_finder('raOffset', SourceType.FIXED, (0, int)), column_finder('decOffset', SourceType.FIXED, (0, int)), column_finder('spatialmodel', SourceType.CONFIG, - (f'object_types/{cmp}_basic/spatial_model', 'str')), - column_finder('majorAxis', SourceType.DATA, 'size_disk_true'), - column_finder('minorAxis', SourceType.DATA, 'size_minor_disk_true'), - column_finder('positionAngle', SourceType.DATA, 'position_angle_unlensed'), + (f'object_types/{cmp}_basic/spatial_model', + 'str')), + column_finder('majorAxis', SourceType.DATA, + 'size_disk_true'), + column_finder('minorAxis', SourceType.DATA, + 'size_minor_disk_true'), + column_finder('positionAngle', SourceType.DATA, + 'position_angle_unlensed'), column_finder('sindex', SourceType.DATA, 'n_knots'), - column_finder('internalExtinctionModel', SourceType.FIXED, ('none', np.dtype('U4'))), + column_finder('internalExtinctionModel', SourceType.FIXED, + ('none', np.dtype('U4'))), column_finder('galacticExtinctionModel', SourceType.CONFIG, - (f'object_types/{cmp}_basic/MW_extinction', 'str')), + (f'object_types/{cmp}_basic/MW_extinction', + 'str')), column_finder('galactivAv', SourceType.DATA, 'MW_av'), column_finder('galacticRv', SourceType.CONFIG, ('MW_extinction_values/r_v/value', float))] return cmp_instance -### + def form_cmp_instance_columns(cmp, band): if cmp == 'knots': return _form_knots_instance_columns(cmp, band) - cmp_instance = [column_finder('prefix', SourceType.FIXED, ('object', np.dtype('U6'))), - column_finder('uniqueId', SourceType.COMPUTE, ['galaxy_id', f'{cmp}']), + cmp_instance = [column_finder('prefix', SourceType.FIXED, + ('object', np.dtype('U6'))), + column_finder('uniqueId', SourceType.COMPUTE, + ['galaxy_id', f'{cmp}']), column_finder('raPhoSim', SourceType.DATA, 'ra'), column_finder('decPhoSim', SourceType.DATA, 'dec'), - column_finder('phoSimMagNorm', SourceType.DATA, f'{cmp}_magnorm'), - column_finder('sedFilepath',SourceType.COMPUTE, [f'sed_val_{cmp}','redshift_hubble']), + column_finder('phoSimMagNorm', SourceType.DATA, + f'{cmp}_magnorm'), + column_finder('sedFilepath', SourceType.COMPUTE, + [f'sed_val_{cmp}', 'redshift_hubble']), column_finder('redshift', SourceType.DATA, 'redshift'), column_finder('gamma1', SourceType.DATA, 'shear_1'), column_finder('gamma2', SourceType.DATA, 'shear_2'), @@ -81,23 +105,30 @@ def form_cmp_instance_columns(cmp, band): column_finder('raOffset', SourceType.FIXED, (0, int)), column_finder('decOffset', SourceType.FIXED, (0, int)), column_finder('spatialmodel', SourceType.CONFIG, - (f'object_types/{cmp}_basic/spatial_model', 'str')), - column_finder('majorAxis', SourceType.DATA, f'size_{cmp}_true'), - column_finder('minorAxis', SourceType.DATA, f'size_minor_{cmp}_true'), - column_finder('positionAngle', SourceType.DATA, 'position_angle_unlensed'), + (f'object_types/{cmp}_basic/spatial_model', + 'str')), + column_finder('majorAxis', SourceType.DATA, + f'size_{cmp}_true'), + column_finder('minorAxis', SourceType.DATA, + f'size_minor_{cmp}_true'), + column_finder('positionAngle', SourceType.DATA, + 'position_angle_unlensed'), column_finder('sindex', SourceType.DATA, f'sersic_{cmp}'), - column_finder('internalExtinctionModel', SourceType.FIXED, ('none', np.dtype('U4'))), + column_finder('internalExtinctionModel', SourceType.FIXED, + ('none', np.dtype('U4'))), column_finder('galacticExtinctionModel', SourceType.CONFIG, - (f'object_types/{cmp}_basic/MW_extinction', 'str')), + (f'object_types/{cmp}_basic/MW_extinction', + 'str')), column_finder('galactivAv', SourceType.DATA, 'MW_av'), column_finder('galacticRv', SourceType.CONFIG, ('MW_extinction_values/r_v/value', float))] return cmp_instance -### + column_finder = namedtuple('ColumnFinder', ['instance_name', 'source_type', 'source_parm']) + SourceType = Enum('SourceType', 'DATA CONFIG FIXED COMPUTE') ''' Used in source_type field of column_finder to describe source of each @@ -109,15 +140,17 @@ def form_cmp_instance_columns(cmp, band): COMPUTE Arbitrary computation which may involve any of the above ''' + def check_file(path): '''Look for a file that should not exist''' try: f = open(path, mode='r') - except FileNotFoundError as e: + except FileNotFoundError: return raise ValueError(f'File for {path} already exists') + def write_to_instance(handle, ordered_data_dict, fmt): ''' Parameters @@ -129,6 +162,7 @@ def write_to_instance(handle, ordered_data_dict, fmt): for row in zip(*col_list): handle.write(fmt.format(*row) + '\n') + def form_object_string(obj, band, component): ''' parse columns for this object/component type @@ -179,18 +213,19 @@ def form_object_string(obj, band, component): if cmp not in ['disk', 'bulge', 'knots']: raise ValueError(f'translate_utils.form_object_string: Bad COMPUTE entry {c.instance_name} for component {cmp}') row.append(f'{str(obj.get_native_attribute("galaxy_id"))}_{cmp}') - else: # uniquePsId. Output must be string but input may be int + else: # uniquePsId. Output must be string but input may be int row.append(f'{str(obj.get_native_attribute("id"))}') else: raise ValueError(f'translate_utils.form_object_string: Unknown source type {c.source_type}') return write_to_string(row, fmt) + def write_to_string(row, fmt): ''' Output single string, composed from input values in row. Types of values in row must match those expected by fmt. - Parameters +y Parameters ---------- row list of values fmt fmt string.