Skip to content

Commit

Permalink
address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanneBogart committed Aug 2, 2024
1 parent 820aa6c commit fc25144
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 69 deletions.
26 changes: 16 additions & 10 deletions skycatalogs/objects/base_config_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@


class BaseConfigFragment():
def __init__(self, prov, object_type_name=None, template_name=None):
def __init__(self, prov, object_type_name=None, template_name=None,
area_partition=None, data_file_type=None):
self._object_type_name = object_type_name
self._prov = prov
self._opt_dict = dict()
if area_partition:
self._opt_dict['area_partition'] = area_partition
if data_file_type:
self._opt_dict['data_file_type'] = data_file_type
self._template_name = template_name
if not template_name:
if object_type_name:
Expand All @@ -42,20 +48,20 @@ def make_fragment(self):
for the object type.
Must be implemented by subclass
'''
raise NotImplementedError("Must be implemented by subclass")
return self.generic_create()

def generic_create(self):
template_path = os.path.join(_TEMPLATE_DIR, self._template_name)
with open(template_path, 'r') as f:
data = yaml.load(f, Loader=yaml.SafeLoader)
if self._opt_dict:
opt = self._opt_dict
other = dict()
for key in opt:
if opt[key] is not None:
other[key] = opt[key]
if len(other.keys()) > 0:
data.update(other)

opt = self._opt_dict
other = dict()
for key in opt:
if opt[key] is not None:
other[key] = opt[key]
if len(other.keys()) > 0:
data.update(other)

data['provenance'] = self._prov
return data
25 changes: 21 additions & 4 deletions skycatalogs/objects/base_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Sequence, Iterable
from collections import namedtuple
import os
import logging
import numpy as np
import galsim
from galsim.roman import longwave_bands as roman_longwave_bands
Expand All @@ -26,7 +27,7 @@
# global for easy access for code run within mp


def load_lsst_bandpasses():
def _load_lsst_bandpasses():
'''
Read in lsst bandpasses from standard place, trim, and store in global dict
Returns: The dict and throughputs version
Expand All @@ -48,14 +49,16 @@ def load_lsst_bandpasses():
if os.path.exists(bp_dir):
BaseObject._bp_path = bp_dir
else:
# logger.info('Using galsim built-in bandpasses')
logger = logging.getLogger('skyCatalogs:load_lsst_bandpasses')
logger.warning('Using galsim built-in bandpasses which may not be up to date')
bp_dir = None
if bp_dir:
with open(os.path.join(bp_dir, 'version_info')) as f:
version = f.read().strip()
else:
version = 'galsim_builtin'


for band in LSST_BANDS:
if bp_dir:
bp_full_path = os.path.join(bp_dir, f'total_{band}.dat')
Expand All @@ -76,16 +79,30 @@ def load_lsst_bandpasses():

return lsst_bandpasses, version

def load_lsst_bandpasses():
'''
Read in lsst bandpasses from standard place, trim, and store in global dict
Returns
-------
The bandpasses
'''
return _load_lsst_bandpasses()[0]

def load_roman_bandpasses():
def _load_roman_bandpasses():
'''
Read in Roman bandpasses from standard place, trim, and store in global dict
Returns: The dict
Returns: The dict and version inforation
'''
global roman_bandpasses
roman_bandpasses = roman_getBandpasses()
return roman_bandpasses, 'galsim_builtin'

def load_roman_bandpasses():
'''
Read in Roman bandpasses from standard place, trim, and store in global dict
Returns: The dict
'''
return load_roman_bandpasses()[0]

class BaseObject(object):
'''
Expand Down
14 changes: 4 additions & 10 deletions skycatalogs/objects/diffsky_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,7 @@ def get_observer_sed_component(self, component, mjd=None, resolution=None):
class DiffskyConfigFragment(BaseConfigFragment):
def __init__(self, prov, cosmology,
area_partition=None, data_file_type=None):
self.super().__init__(prov, object_type_name='diffsky_galaxy')

self._opt_dict = {'area_partition': area_partition,
'data_file_type': data_file_type}
self._cosmology = cosmology

def make_fragment(self):
data = self.generic_create()
data['Cosmology'] = self._cosmology
return data
super().__init__(prov, object_type_name='diffsky_galaxy',
area_partition=area_partition,
data_file_type=data_file_type)
self._opt_dict['Cosmology'] = cosmology
21 changes: 8 additions & 13 deletions skycatalogs/objects/gaia_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def __len__(self):
class GaiaConfigFragment(BaseConfigFragment):
def __init__(self, prov, id_prefix=None, use_butler=False,
butler_parameters=None, area_partition=None,
data_file_type=None,
data_dir=None, basename_template=None):
'''
prov dict Provenance
Expand All @@ -411,6 +412,7 @@ def __init__(self, prov, id_prefix=None, use_butler=False,
each object
butler_parameters dict Used only if use_butler is true
area_partition dict Used only if use_butler is False
data_file_type dict Used only if use_butler is False
data_dir string Used only if use_butler is False
basename_template string Used only if use_butler is False
Expand All @@ -424,22 +426,15 @@ def __init__(self, prov, id_prefix=None, use_butler=False,
template_name = 'gaia_star_direct_template.yaml'

super().__init__(prov, object_type_name='gaia_star',
template_name=template_name)
template_name=template_name,
area_partition=area_partition,
data_file_type=data_file_type)

opt_dict = dict()
if id_prefix:
opt_dict['id_prefix'] = id_prefix
self._opt_dict['id_prefix'] = id_prefix
if use_butler:
if butler_parameters:
opt_dict['butler_parameters'] = butler_parameters
self._opt_dict['butler_parameters'] = butler_parameters
else:
if area_partition:
opt_dict['area_partition'] = area_partition
if data_dir:
opt_dict['data_dir'] = data_dir
if basename_template:
opt_dict['basename_template'] = basename_template
self._opt_dict = opt_dict

def make_fragment(self):
return self.generic_create()
self._opt_dict['basename_template'] = basename_template
10 changes: 4 additions & 6 deletions skycatalogs/objects/galaxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,13 @@ def get_instcat_entry(self, band='r', component=None):
class GalaxyConfigFragment(BaseConfigFragment):
def __init__(self, prov, cosmology, tophat_bins,
area_partition=None, data_file_type=None):
super().__init__(prov, object_type_name='galaxy')

self._opt_dict = {'area_partition': area_partition,
'data_file_type': data_file_type}
self._cosmology = cosmology
super().__init__(prov, object_type_name='galaxy',
area_partition=area_partition,
data_file_type=data_file_type)
self._opt_dict['Cosmology'] = cosmology
self._tophat_bins = tophat_bins

def make_fragment(self):
data = self.generic_create()
data['tophat']['bins'] = self._tophat_bins
data['Cosmology'] = self._cosmology
return data
10 changes: 3 additions & 7 deletions skycatalogs/objects/snana_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,6 @@ def __init__(self, ra, dec, id, object_type, partition_id, sky_catalog,

class SnanaConfigFragment(BaseConfigFragment):
def __init__(self, prov, area_partition=None, data_file_type=None):
super().__init__(prov, object_type_name='snana')

self._opt_dict = {'area_partition': area_partition,
'data_file_type': data_file_type}

def make_fragment(self):
return self.generic_create()
super().__init__(prov, object_type_name='snana',
area_partition=area_partition,
data_file_type=data_file_type)
10 changes: 3 additions & 7 deletions skycatalogs/objects/sso_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,6 @@ def __getitem__(self, key):

class SsoConfigFragment(BaseConfigFragment):
def __init__(self, prov, area_partition=None, data_file_type=None):
super().__init__(prov, object_type_name='sso')

self._opt_dict = {'area_partition': area_partition,
'data_file_type': data_file_type}

def make_fragment(self):
return self.generic_create()
super().__init__(prov, object_type_name='sso',
area_partition=area_partition,
data_file_type=data_file_type)
11 changes: 3 additions & 8 deletions skycatalogs/objects/star_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
# import numpy as np
import galsim
from .base_object import BaseObject
from ..utils import normalize_sed
Expand Down Expand Up @@ -38,10 +37,6 @@ def get_observer_sed_component(self, component, mjd=None):

class StarConfigFragment(BaseConfigFragment):
def __init__(self, prov, area_partition=None, data_file_type=None):
super().__init__(prov, object_type_name='star')

self._opt_dict = {'area_partition': area_partition,
'data_file_type': data_file_type}

def make_fragment(self):
return self.generic_create()
super().__init__(prov, object_type_name='star',
area_partition=area_partition,
data_file_type=data_file_type)
7 changes: 4 additions & 3 deletions skycatalogs/skyCatalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy.ma as ma
from astropy import units as u
import lsst.sphgeom
from skycatalogs.objects.base_object import load_lsst_bandpasses, load_roman_bandpasses
from skycatalogs.objects.base_object import _load_lsst_bandpasses
from skycatalogs.objects.base_object import _load_roman_bandpasses
from skycatalogs.utils.catalog_utils import CatalogContext
from skycatalogs.objects.base_object import ObjectList, ObjectCollection
from skycatalogs.objects.gaia_object import GaiaObject, GaiaCollection
Expand Down Expand Up @@ -803,6 +804,6 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False):
verbose=verbose)

# Get bandpasses in case we need to compute fluxes
_, cat._lsst_thru_v = load_lsst_bandpasses()
_, cat._roman_thru_v = load_roman_bandpasses()
_, cat._lsst_thru_v = _load_lsst_bandpasses()
_, cat._roman_thru_v = _load_roman_bandpasses()
return cat
15 changes: 14 additions & 1 deletion skycatalogs/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,10 @@ def _read_yaml(inpath, silent=True, resolve_include=True):
inpath string path to file
silent boolean if file not found and silent is true,
return None. Else raise exception
resolve_include boolean if False, return values like
!include star.yaml
literally. If True, replace with content
of references file
Returns
-------
Expand Down Expand Up @@ -502,6 +506,15 @@ def write_yaml(self, input_dict, outpath):
Write yaml file if
* it doesn't already exist or
* we're allowed to overwrite
Parameters
----------
input_dict dict Contents to be output to yaml
outpath string Where to write the output
Returns
-------
output path (same as argument) if a file is written, else None
'''
if self._overwrite:
return self.update_yaml(input_dict, outpath)
Expand All @@ -512,7 +525,7 @@ def write_yaml(self, input_dict, outpath):
except FileExistsError:
txt = 'write_yaml: Will not overwrite pre-existing config'
self._logger.warning(txt + outpath)
return
return None

return outpath

Expand Down

0 comments on commit fc25144

Please sign in to comment.