Skip to content

Commit

Permalink
Merge pull request #73 from LSSTDESC/u/jrbogart/flake
Browse files Browse the repository at this point in the history
U/jrbogart/flake
  • Loading branch information
JoanneBogart authored Oct 14, 2023
2 parents 7c68b00 + c30f271 commit 66cd16d
Show file tree
Hide file tree
Showing 23 changed files with 442 additions and 407 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ target/

# local sandbox
local/

# local flake8 running
.flake8
151 changes: 81 additions & 70 deletions skycatalogs/catalog_creator.py

Large diffs are not rendered by default.

71 changes: 36 additions & 35 deletions skycatalogs/objects/base_object.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from collections.abc import Sequence, Iterable
from collections import namedtuple
import os
import gzip
import itertools
import numpy as np
import astropy.units as u
import warnings
from dust_extinction.parameter_averages import F19
import galsim
import logging
from galsim.errors import GalSimRangeError

from skycatalogs.utils.translate_utils import form_object_string
from skycatalogs.utils.config_utils import Config
Expand All @@ -26,6 +20,8 @@
LSST_BANDS = ('ugrizy')

# global for easy access for code run within mp


def load_lsst_bandpasses():
'''
Read in lsst bandpasses from standard place, trim, and store in global dict
Expand All @@ -40,17 +36,17 @@ def load_lsst_bandpasses():

if os.path.exists(bp_dir):
BaseObject._bp_path = bp_dir
#logger.info(f'Using rubin sim dir {rubin_sim_dir}')
# logger.info(f'Using rubin sim dir {rubin_sim_dir}')
else:
bp_dir = os.path.join(os.getenv('HOME'), 'rubin_sim_data',
'throughputs', 'baseline')
#logger.info(f'Using rubin sim dir rubin_sim_data under HOME')
'throughputs', 'baseline')
# logger.info(f'Using rubin sim dir rubin_sim_data under HOME')
if os.path.exists(bp_dir):
BaseObject._bp_path = bp_dir
else:
#logger.info('Using galsim built-in bandpasses')
# logger.info('Using galsim built-in bandpasses')
bp_dir = None
fname_fmt = 'LSST_{band}.dat'

for band in LSST_BANDS:
if bp_dir:
bp_full_path = os.path.join(bp_dir, f'total_{band}.dat')
Expand All @@ -60,22 +56,25 @@ def load_lsst_bandpasses():

# Mirror behavior in imsim.RubinBandpass:
# https://github.com/LSSTDESC/imSim/blob/main/imsim/bandpass.py#L9
# Trim the edges to avoid 1.e-4 values out to very high and low wavelengths.
# Trim the edges to avoid 1.e-4 values out to very high and low
# wavelengths.
bp = bp.truncate(relative_throughput=1.e-3)
# Remove wavelength values selectively for improved speed but preserve flux integrals.
# Remove wavelength values selectively for improved speed but
# preserve flux integrals.
bp = bp.thin()
bp = bp.withZeropoint('AB')
lsst_bandpasses[band] = bp

return lsst_bandpasses


class BaseObject(object):
'''
Abstract base class for static (in position coordinates) objects.
Likely need a variant for SSO.
'''

_bp500 = galsim.Bandpass(galsim.LookupTable([499, 500, 501],[0, 1, 0]),
_bp500 = galsim.Bandpass(galsim.LookupTable([499, 500, 501], [0, 1, 0]),
wave_type='nm').withZeropoint('AB')

def __init__(self, ra, dec, id, object_type, belongs_to, belongs_index,
Expand Down Expand Up @@ -103,7 +102,6 @@ def __init__(self, ra, dec, id, object_type, belongs_to, belongs_index,
# All objects also include redshift information. Also MW extinction,
# but extinction is by subcomponent for galaxies


@property
def ra(self):
return self._ra
Expand All @@ -122,7 +120,8 @@ def object_type(self):

@property
def redshift(self):
if self._redshift: return self._redshift
if self._redshift:
return self._redshift
if self._belongs_to:
self._redshift = self.get_native_attribute('redshift')
return self._redshift
Expand Down Expand Up @@ -173,7 +172,7 @@ def get_sed_metadata(self, **kwargs):
'''
raise NotImplementedError

def get_instcat_entry(self, band = 'r', component=None):
def get_instcat_entry(self, band='r', component=None):
'''
Return the string corresponding to instance catalog line
Parameters:
Expand Down Expand Up @@ -204,7 +203,6 @@ def _get_sed(self, component=None, resolution=None, mjd=None):

raise NotImplementedError('Must be implemented by BaseObject subclass if needed')


def write_sed(self, sed_file_path, component=None, resolution=None,
mjd=None):
sed, _ = self._get_sed(component=component, resolution=None, mjd=None)
Expand Down Expand Up @@ -284,11 +282,11 @@ def get_total_observer_sed(self, mjd=None):
components.
"""
sed = None
for sed_component in self.get_observer_sed_components(mjd=mjd).values():
for sed_cmp in self.get_observer_sed_components(mjd=mjd).values():
if sed is None:
sed = sed_component
sed = sed_cmp
else:
sed += sed_component
sed += sed_cmp

return sed

Expand Down Expand Up @@ -318,7 +316,7 @@ def get_fluxes(self, bandpasses, mjd=None):
return [sed.calculateFlux(b) for b in bandpasses]

def get_LSST_flux(self, band, sed=None, cache=True, mjd=None):
if not band in LSST_BANDS:
if band not in LSST_BANDS:
return None
att = f'lsst_flux_{band}'

Expand Down Expand Up @@ -463,7 +461,8 @@ def get_native_attribute(self, attribute_name):
exists.
'''
val = getattr(self, attribute_name, None)
if val is not None: return val
if val is not None:
return val

for r in self._rdrs:
if attribute_name in r.columns:
Expand Down Expand Up @@ -510,7 +509,7 @@ def get_native_attributes_iterator(self, attribute_names):
-------
iterator which returns df for a chunk of values of the attributes
'''
pass # for now
pass # for now

# implement Sequence methods
def __contains__(self, obj):
Expand All @@ -519,7 +518,7 @@ def __contains__(self, obj):
----------
obj can be an (object id) or of type BaseObject
'''
if type(obj) == type(10):
if isinstance(obj, int):
id = obj
else:
if isinstance(obj, BaseObject):
Expand All @@ -531,8 +530,8 @@ def __contains__(self, obj):
def __len__(self):
return len(self._id)

#def __iter__(self): Standard impl based on __getitem__ should be ok
#def __reversed__(self): Default implementation ok
# def __iter__(self): Standard impl based on __getitem__ should be ok
# def __reversed__(self): Default implementation ok

def __getitem__(self, key):
'''
Expand All @@ -556,7 +555,7 @@ def __getitem__(self, key):
elif type(key) == slice:
if key.start is None:
key.start = 0
ixdata = [i for i in range(min(key.stop,len(self._ra)))]
ixdata = [i for i in range(min(key.stop, len(self._ra)))]
ixes = itertools.islice(ixdata, key.start, key.stop, key.step)
return [self._object_class(self._ra[i], self._dec[i], self._id[i],
object_type, self, i)
Expand All @@ -568,15 +567,15 @@ def __getitem__(self, key):
object_type, self, i)
for i in key[0]]


def get_partition_id(self):
return self._partition_id

def count(self, obj):
'''
returns # of occurrences of obj. It can only be 0 or 1
'''
if self.__contains__(obj): return 1
if self.__contains__(obj):
return 1
return 0

def index(self, obj):
Expand All @@ -585,6 +584,7 @@ def index(self, obj):
'''
return self._id.index(obj.id)


LocatedCollection = namedtuple('LocatedCollection',
['collection', 'first_index', 'upper_bound'])
'''
Expand All @@ -595,6 +595,7 @@ def index(self, obj):
so upper_bound for one collection = first_index in the next
'''


class ObjectList(Sequence):
'''
Keep track of a list of ObjectCollection objects, but from user
Expand Down Expand Up @@ -642,13 +643,13 @@ def get_collections(self):
return constituent ObjectCollection objects in a list
'''
collections = []
for e in self._located:
for e in self._located:
collections.append(e.collection)

return collections


# implement Sequence methods

def __contains__(self, obj):
'''
Parameters
Expand All @@ -665,8 +666,8 @@ def __contains__(self, obj):
def __len__(self):
return self._total_len

#def __iter__(self): Standard impl based on __getitem__ should be ok??
#def __reversed__(self): Default implementation ok??
# def __iter__(self): Standard impl based on __getitem__ should be ok??
# def __reversed__(self): Default implementation ok??

def __getitem__(self, key):
'''
Expand Down Expand Up @@ -709,7 +710,7 @@ def __getitem__(self, key):
sub = [elem - e.first_index for elem in key_list if elem >= e.first_index and elem < e.upper_bound]
to_return += e.collection[(sub,)]

start_ix +=len(sub)
start_ix += len(sub)
if start_ix >= len(key_list):
break
start = key_list[start_ix]
Expand Down
21 changes: 14 additions & 7 deletions skycatalogs/objects/gaia_object.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import sys
import warnings
from functools import wraps
import itertools
from collections.abc import Iterable
from pathlib import PurePath
import numpy as np
import erfa
Expand Down Expand Up @@ -60,6 +60,7 @@ class GaiaObject(BaseObject):
_stellar_temperature = _TEMP_FUNC
_gaia_bp_bandpass = _GAIA_BP
_wavelengths = np.arange(250, 1250, 5, dtype=float)

def __init__(self, obj_pars, parent_collection, index):
"""
Parameters
Expand Down Expand Up @@ -92,9 +93,9 @@ def __init__(self, obj_pars, parent_collection, index):
# Convert from flux units of nJy to AB mag for the bp passband,
# which we will use to normalize the SED.
self.bp_mag = -2.5*np.log10(bp_flux*1e-9) + 8.90
except galsim.errors.GalSimRangeError as ex:
except galsim.errors.GalSimRangeError:
self.stellar_temp = None
except RuntimeError as rex:
except RuntimeError:
self.stellar_temp = None

def blambda(self, wl):
Expand Down Expand Up @@ -126,16 +127,21 @@ def get_gsobject_components(self, gsparams=None, rng=None):
def set_use_lut(self, use_lut):
self.use_lut = use_lut


class GaiaCollection(ObjectCollection):
# Class methods
_gaia_config = None
def set_config(config):

@classmethod
def set_config(cls, config):
GaiaCollection._gaia_config = config

def get_config():
@classmethod
def get_config(cls):
return GaiaCollection._gaia_config

@ignore_erfa_warnings
@staticmethod
def load_collection(region, skycatalog, mjd=None):
if isinstance(region, Disk):
ra = lsst.geom.Angle(region.ra, lsst.geom.degrees)
Expand Down Expand Up @@ -229,13 +235,14 @@ def __getitem__(self, key):
return GaiaObject(self.df.iloc[key], self, key)

elif type(key) == slice:
ixdata = [i for i in range(min(key.stop,len(self._id)))]
ixdata = [i for i in range(min(key.stop, len(self._id)))]
ixes = itertools.islice(ixdata, key.start, key.stop, key.step)
return [self._object_class(self.df.iloc[i], self, i) for i in ixes]

elif type(key) == tuple and isinstance(key[0], Iterable):
# check it's a list of int-like?
return [self._object_class(self.df.iloc[i], self, i) for i in key[0]]
return [self._object_class(self.df.iloc[i], self,
i) for i in key[0]]

def __len__(self):
return len(self.df)
12 changes: 7 additions & 5 deletions skycatalogs/objects/galaxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import galsim

from .base_object import BaseObject
from skycatalogs.utils.translate_utils import form_object_string

__all__ = ['GalaxyObject']


class GalaxyObject(BaseObject):
_type_name = 'galaxy'

def _get_sed(self, component=None, resolution=None):
'''
Return sed and mag_norm for a galaxy component or for a star
Expand All @@ -25,7 +28,7 @@ def _get_sed(self, component=None, resolution=None):
raise ValueError(f'Cannot fetch SED for component type {component}')

th_val = self.get_native_attribute(f'sed_val_{component}')
if th_val is None: # values for this component are not in the file
if th_val is None: # values for this component are not in the file
raise ValueError(f'{component} not part of this catalog')

# if values are all zeros or nearly no point in trying to convert
Expand All @@ -46,11 +49,11 @@ def get_wl_params(self):
"""Return the weak lensing parameters, g1, g2, mu."""
gamma1 = self.get_native_attribute('shear_1')
gamma2 = self.get_native_attribute('shear_2')
kappa = self.get_native_attribute('convergence')
kappa = self.get_native_attribute('convergence')
# Compute reduced shears and magnification.
g1 = gamma1/(1. - kappa) # real part of reduced shear
g2 = gamma2/(1. - kappa) # imaginary part of reduced shear
mu = 1./((1. - kappa)**2 - (gamma1**2 + gamma2**2)) # magnification
mu = 1./((1. - kappa)**2 - (gamma1**2 + gamma2**2)) # magnification
return g1, g2, mu

def get_total_observer_sed(self, mjd=None):
Expand All @@ -67,7 +70,6 @@ def get_total_observer_sed(self, mjd=None):
sed *= mu
return sed


def get_gsobject_components(self, gsparams=None, rng=None):

if gsparams is not None:
Expand Down Expand Up @@ -123,7 +125,7 @@ def get_observer_sed_component(self, component, mjd=None, resolution=None):

return sed

def get_instcat_entry(self, band = 'r', component=None):
def get_instcat_entry(self, band='r', component=None):
'''
Return the string corresponding to instance catalog line
Parameters:
Expand Down
Loading

0 comments on commit 66cd16d

Please sign in to comment.