
Commit

Fixed merge errors
Nschanche committed Apr 2, 2024
2 parents a7651d3 + 8ceb9eb commit d76a827
Showing 1 changed file with 23 additions and 128 deletions.
151 changes: 23 additions & 128 deletions src/newlk_search/search.py
@@ -4,38 +4,33 @@
import re
import logging
import warnings
from lightkurve.utils import (
LightkurveDeprecationWarning,
LightkurveError,
LightkurveWarning,
suppress_stdout,
)

import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy.time import Time
from lightkurve.io import read

from copy import deepcopy

# import cache
# from .config import conf, config
from . import PACKAGEDIR, PREFER_CLOUD, DOWNLOAD_CLOUD, conf, config

default_download_dir = config.get_cache_dir()

from memoization import cached



log = logging.getLogger(__name__)


class SearchError(Exception):
pass

class SearchWarning(Warning):
pass

class MASTSearch(object):
"""
@@ -74,10 +69,6 @@ class MASTSearch(object):
Mission Specific Survey value that corresponds to Sector (TESS), Campaign (K2), or Quarter (Kepler)
"""

_REPR_COLUMNS = [
"target_name",
"pipeline",
@@ -88,7 +79,6 @@ class MASTSearch(object):
"description",
]

# Defined at class level; a recursion error occurs otherwise
table = None

def __init__(
@@ -110,15 +100,14 @@ def __init__(
pipeline = np.atleast_1d(pipeline).tolist()
self.search_pipeline = pipeline
self.search_sequence = sequence

# Legacy functionality - KIC/TIC targets can no longer be queried by integer value alone
if isinstance(target, int):
raise TypeError(
"Target must be a target name string, (ra, dec) tuple, "
"or astropy coordinate object"
)

# If target is not None, parse the input
# TODO: get rid of saving prod and obs to self
self.target = target
if table is None:
self._target_from_name(target)
@@ -146,7 +135,8 @@ def _target_from_name(self, target):
self.table = self.table[mask]

def _target_from_table(self, table, obs_table, prod_table):
# see if the function was passed a joint table
if isinstance(table, pd.DataFrame):
self.table = table

@@ -239,8 +229,7 @@ def __repr__(self):
else:
return "I am an uninitialized MASTSearch result"

# Calls the pandas table HTML output, which is nicer
def _repr_html_(self):
if isinstance(self.table, pd.DataFrame):
return self.table[self._REPR_COLUMNS]._repr_html_()
@@ -258,22 +247,19 @@ def __getitem__(self, key):
return self._mask(key)

def _mask(self, mask):
"""Masks down the product and observation tables given an input mask, then returns them as a new K2Search object."""
"""Masks down the product and observation tables given an input mask, then returns them as a new Search object.
deepcopy is used to preserve the class metadata stored in class variables"""
new_table = deepcopy(self)
new_table.table = self.table[mask].reset_index()

return new_table
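# Example (sketch): __getitem__ above forwards masks/keys here, so an
# initialized result `sr` can be narrowed without mutating the original:
#
#   spoc_only = sr[sr.table["pipeline"].values == "SPOC"]
#   first_three = sr[np.arange(len(sr.table)) < 3]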

# May be overwritten by the individual KeplerSearch/TESSSearch/K2Search classes
def _update_table(self, joint_table):
# copy columns
joint_table = joint_table.rename(columns={"t_exptime": "exptime"})
joint_table["pipeline"] = joint_table["provenance_name"].copy()
joint_table["mission"] = joint_table["obs_collection_obs"].copy()

# rename identical columns
joint_table.rename(
columns={
@@ -287,22 +273,6 @@ def _update_table(self, joint_table):
)
joint_table = joint_table.reset_index()

return joint_table

def _fix_table_times(self, joint_table):
@@ -324,30 +294,7 @@ def _fix_table_times(self, joint_table):
).iso

return joint_table
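# Worked example of the time fix above: MAST reports t_min/t_max as MJD,
# and adding 2400000.5 converts to JD before formatting as ISO:
#
#   from astropy.time import Time
#   Time(59000.0 + 2400000.5, format="jd").iso  # '2020-05-31 00:00:00.000'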

"""
Full list of features
['intentType', 'obscollection_obs', 'provennce_name',
'instrument_name', 'project_obs', 'filters', 'wavelength_region',
'target_name', 'target_classification', 'obs_id', 's_ra', 's_dec',
'dataproduct_type_obs', 'proposal_pi', 'calib_level_obs', 't_min',
't_max', 't_exptime', 'em_min', 'em_max', 'obs_title', 't_obs_release',
'proposal_id_obs', 'proposal_type', 'sequence_number', 's_region',
'jpegURL', 'dataURL', 'dataRights_obs', 'mtFlag', 'srcDen', 'obsid',
'objID', 'objID1', 'distance', 'obsID', 'obs_collection_prod',
'dataproduct_type_prod', 'description', 'type', 'dataURI',
'productType', 'productGroupDescription', 'productSubGroupDescription',
'productDocumentationURL', 'project_prod', 'prvversion',
'proposal_id_prod', 'productFilename', 'size', 'parent_obsid',
'dataRights_prod', 'calib_level_prod']"""

# Other additions may include the following
# self._add_columns("something")
# self._add_urls_to_authors()
# self._add_s3_url_column()
# self._sort_by_priority()



def _search(
self,
search_radius: Union[float, u.Quantity] = None,
@@ -414,7 +361,7 @@ def _parse_input(self, search_input):
)

def _add_s3_url_column(self, joint_table):
"""self.table will be updated to have an extra column of S3 URLs if possible."""
Observations.enable_cloud_dataset()
cloud_uris = Observations.get_cloud_uris(
Table.from_pandas(joint_table), full_url=True
@@ -432,29 +379,13 @@ def _search_obs(
sequence=None,
cadence=None,
):
# Helper function that returns a Search Result object containing MAST products
# combines the results of Observations.query_criteria (called via self.query_mast) and Observations.get_product_list

"""if [bool(quarter),
bool(campaign),
bool(sector)].count(True) > 1:
raise LightkurveError("Ambiguity Error; multiple quarter/campaign/sector specified."
"If searching for specific data across different missions, perform separate searches by mission.")
"""

# Is this where we want the error thrown for an FFI search in MASTSearch?
if filetype == "ffi":
raise SearchError(
f"FFI search not implemented in MASTSearch. Please use TESSSearch."
)
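# FFI searches are instead routed through the TESS-specific subclass, e.g.
# (sketch; the exact search_individual_ffi signature is truncated in this
# diff, so the arguments are assumptions):
#
#   TESSSearch("TIC 261136679").search_individual_ffi(...)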

# Ensure mission is a list
mission = np.atleast_1d(mission).tolist()
if pipeline is not None:
@@ -505,7 +436,6 @@ def _query_mast(
):
from astroquery.exceptions import NoResultsWarning, ResolverError

# Constructs the appropriate query for MAST
log.debug(f"Searching for {self.target} with {exptime} on project {project}")

@@ -605,11 +535,6 @@ def cubedata(self):
return self._mask(mask)


def limit_results(self, limit: int):
mask = np.ones(len(self.table), dtype=bool)
mask[limit:] = False
@@ -726,20 +651,9 @@ def _filter(
else:
exptime_mask = ~mask

"""# If no products are left, return an empty dataframe with the same columns
if sum(mask) == 0:
return pd.DataFrame(columns = products.keys())
products = products[mask]
products.sort_values(by=["distance", "productFilename"], ignore_index=True)
return products"""
# This hidden filter function now just returns the mask
mask = file_mask & project_mask & provenance_mask & exptime_mask
return mask
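# Callers are then expected to apply the mask themselves, e.g. (sketch;
# _filter's full keyword list is truncated above, so these names are
# assumptions):
#
#   mask = self._filter(exptime="short", pipeline="SPOC")
#   filtered = self._mask(mask)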

# Again, may want to add to self.mask if we go that route.
def _mask_by_exptime(self, exptime):
"""Helper function to filter by exposure time.
Returns a boolean array"""
@@ -798,9 +712,10 @@ def download(
Caching is more seamless if a user is searching for the same file(s) across different
project directories and has a pipeline workflow with input functions.
"""

if len(self.table) == 0:
warnings.warn(
"Cannot download from an empty search result.", LightkurveWarning
"Cannot download from an empty search result.", SearchWarning
)
return None
if cloud:
@@ -1047,13 +962,15 @@ def search_individual_ffi(self,
**query_criteria,
)

ffi_products = Observations.get_product_list(ffi_obs)
# filter out uncalibrated FFIs & any potential HLSPs
prod_mask = ffi_products['calib_level'] == 2
ffi_products = ffi_products[prod_mask]

new_table = deepcopy(self)

# Unlike the other products, FFIs don't map cleanly via obs_id as advertised, so we invert and add specific column info
new_table.obs_table = ffi_products.to_pandas()
new_table.obs_table['year'] = np.nan

@@ -1086,7 +1003,8 @@ def download(self, cloud: PREFER_CLOUD = True, cache: PREFER_CLOUD = True, cloud
mask = self.table["provenance_name"] == "TESScut"
self._mask(~mask).download()
from astroquery.mast import Tesscut
if cloud:
Tesscut.enable_cloud_dataset()
mf1 = Tesscut.download_cutouts(coordinates=self.SkyCoord,
size=TESScut_size,
sector=self.table['sequence_number'].values[mask],
@@ -1098,24 +1016,6 @@
manifest = mf1.append(mf2)
return manifest
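# End-to-end sketch for TESScut cutouts (TESSSearch's constructor and the
# TESScut_size keyword are taken from this snippet, not a published API):
#
#   sr = TESSSearch("TIC 261136679")
#   manifest = sr.download(cloud=True, TESScut_size=10)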






class KeplerSearch(MASTSearch):
def __init__(
self,
@@ -1147,11 +1047,6 @@ def __init__(
# Can't search MAST with quarter/month directly, so filter on that after the fact.
self.table = self.table[self._filter_kepler(quarter, month)]
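# Quarter/month therefore act as post-query filters, e.g. (sketch, assuming
# the truncated __init__ above accepts these keywords):
#
#   sr = KeplerSearch("Kepler-10", quarter=3)
#   sr.table  # rows restricted to Quarter 3 products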

"""
# Now implemented in base class
def _fix_times():
# Fixes Kepler times
raise NotImplementedError"""

def _handle_kbonus(self):
# KBONUS times are masked as they are invalid for the quarter data
