From 458c7b08a06d1af2efb0b45f7d9fddb61be2a1df Mon Sep 17 00:00:00 2001 From: Kelle Cruz Date: Tue, 16 Apr 2024 17:23:13 -0400 Subject: [PATCH] Add ingest_radial_velocity function. (#475) * new ingest_radial_velocities function and test * Gaia DR3 ingest script * move astrometry tests from test_utils to test_astrometry * doc update units and column names * ingest rv function * tests, better but not passing [skip ci] * new test [skip ci] * package name change to astrodb_utils * use conftest temp_db * astrodb_utils name change and updates * astrodb_utils name change * scripts->utils and temp_db * path to simple * ignore userwarning * revert irrelevant changes * fixes needed for tests * fixes for tests. closes Update regex expressions #486 * add units to rv column names * tidying --- documentation/RadialVelocities.md | 4 +- scripts/ingests/Gaia/ingest_gaiadr3.py | 89 ++++++++++++ simple/schema.py | 4 +- simple/utils/astrometry.py | 188 ++++++++++++++++++++++++- simple/utils/companions.py | 2 +- simple/utils/spectral_types.py | 6 +- tests/conftest.py | 5 +- tests/test_astrometry.py | 167 ++++++++++++++++++++++ tests/test_integrity.py | 2 +- tests/test_utils.py | 70 +-------- 10 files changed, 454 insertions(+), 83 deletions(-) create mode 100644 scripts/ingests/Gaia/ingest_gaiadr3.py create mode 100644 tests/test_astrometry.py diff --git a/documentation/RadialVelocities.md b/documentation/RadialVelocities.md index eb70b0c1f..2b2741d66 100644 --- a/documentation/RadialVelocities.md +++ b/documentation/RadialVelocities.md @@ -7,8 +7,8 @@ Columns marked with an asterisk (*) may not be empty. | Column Name | Description | Unit | Data Type | Key Type | |---|---|---|---|---| | *source | Unique identifier for the source | | String(100) | primary and foreign: Sources.source | -| *radial_velocity | Radial velocity | km/yr | Float | | -| radial_velocity_error | Uncertainty of radial velocity | km/yr | Float | | +| *radial_velocity_km_s | Radial velocity | km/s | Float | | +| radial_velocity_error_km_s | Uncertainty of radial velocity | km/s | Float | | | adopted | Flag indicating if this is the adopted measurement | | Boolean | | | comments | Free form comments | | String(1000) | | | *reference | Reference | | String(30) | primary and foreign: Publications.name | diff --git a/scripts/ingests/Gaia/ingest_gaiadr3.py b/scripts/ingests/Gaia/ingest_gaiadr3.py new file mode 100644 index 000000000..c197059e8 --- /dev/null +++ b/scripts/ingests/Gaia/ingest_gaiadr3.py @@ -0,0 +1,89 @@ +from scripts.ingests.ingest_utils import * +from scripts.ingests.utils import * +from astroquery.gaia import Gaia +from astropy.table import Table, setdiff +from astropy import table +from sqlalchemy import func +import numpy as np +import pandas as pd + +# GLOBAL VARIABLES + +SAVE_DB = True # save the data files in addition to modifying the .db file +RECREATE_DB = True # recreates the .db file from the data files +VERBOSE = False +DATE_SUFFIX = "Jun2022" +# LOAD THE DATABASE +db = load_simpledb("SIMPLE.db", recreatedb=RECREATE_DB) + +logger.setLevel(logging.DEBUG) + + +# Functions +# Querying GaiaDR3 +def query_gaia_dr3(input_table): + print("Gaia DR3 query started") + gaia_query_string = ( + "SELECT *,upload_table.db_names FROM gaiadr3.gaia_source " + "INNER JOIN tap_upload.upload_table ON " + "gaiadr3.gaia_source.source_id = tap_upload.upload_table.dr3_source_id " + ) + job_gaia_query = Gaia.launch_job( + gaia_query_string, + upload_resource=input_table, + upload_table_name="upload_table", + verbose=VERBOSE, + ) + + gaia_data = job_gaia_query.get_results() + + print("Gaia DR3 query complete") + + return gaia_data + + +# Ingesting the GAIADR3 publication +def update_ref_tables(): + ingest_publication( + db, + doi="10.1051/0004-6361/202243940", + publication="GaiaDR3", + description="Gaia Data Release 3.Summary of the content and survey properties", + ignore_ads=True, + ) + + +# update_ref_tables() + + +def add_gaia_rvs(data, ref): + unmasked_rvs = np.logical_not(data["radial_velocity"].mask).nonzero() + rvs = data[unmasked_rvs]["db_names", "radial_velocity", "radial_velocity_error"] + refs = [ref] * len(rvs) + ingest_radial_velocities( + db, rvs["db_names"], rvs["radial_velocity"], rvs["radial_velocity_error"], refs + ) + return + + +dr3_desig_file_string = ( + "scripts/ingests/Gaia/gaia_dr3_designations_" + "Sep2021" + ".xml" +) +gaia_dr3_names = Table.read(dr3_desig_file_string, format="votable") +pd_gaia_dr3_names = gaia_dr3_names.to_pandas + +# Querying the GAIA DR3 Data +# gaia_dr3_data = query_gaia_dr3(gaia_dr3_names) + +# making the data file and then converting the string into an astropy table +dr3_data_file_string = "scripts/ingests/Gaia/gaia_dr3_data_" + DATE_SUFFIX + ".xml" +# gaia_dr3_data.write(dr3_data_file_string, format='votable') +gaia_dr3_data = Table.read(dr3_data_file_string, format="votable") + +# ingest_sources(db, gaia_dr3_data['designation'], 'GaiaDR3') + +add_gaia_rvs(gaia_dr3_data, "GaiaDR3") + +# WRITE THE JSON FILES +if SAVE_DB: + db.save_database(directory="data/") diff --git a/simple/schema.py b/simple/schema.py index 6cdf1a181..085fa6237 100644 --- a/simple/schema.py +++ b/simple/schema.py @@ -234,8 +234,8 @@ class RadialVelocities(Base): nullable=False, primary_key=True, ) - radial_velocity = Column(Float, nullable=False) - radial_velocity_error = Column(Float) + radial_velocity_km_s = Column(Float, nullable=False) + radial_velocity_error_km_s = Column(Float) adopted = Column(Boolean) comments = Column(String(1000)) reference = Column( diff --git a/simple/utils/astrometry.py b/simple/utils/astrometry.py index 6557c2c59..7f3394a55 100644 --- a/simple/utils/astrometry.py +++ b/simple/utils/astrometry.py @@ -1,11 +1,11 @@ -from astropy.table import Table import logging +from typing import Optional, Union from sqlalchemy import and_ import sqlalchemy.exc -from astrodb_utils import ( - AstroDBError, - find_source_in_db, -) +from astropy.units import Quantity +from astropy.table import Table +from astrodbkit2.astrodb import Database +from astrodb_utils import AstroDBError, find_source_in_db, find_publication __all__ = [ @@ -357,3 +357,181 @@ def ingest_proper_motions( updated_source_pm_data.pprint_all() return + + +def ingest_radial_velocity( + db: Database, + *, + source: str = None, + rv: Union[Quantity, float] = None, + rv_err: Optional[Union[float, Quantity]] = None, + reference: str = None, + raise_error: bool = True, +): + """ + + Parameters + ---------- + db: astrodbkit2.astrodb.Database + Database object + source: str + source name + rv: float or str + radial velocity of the sources + if not a Quantity, assumed to be in km/s + rv_err: float or str + radial velocity uncertainty + if not a Quantity, assumed to be in km/s + reference: str + reference for the radial velocity data + raise_error: bool + If True, raise errors. If False, log an error and return. + + Returns + ------- + flags: dict + 'added' : bool + 'skipped' : bool + + Examples + ---------- + > ingest_radial_velocity(db, my_source, rv=my_rv, rv_err=my_rv_unc, + reference=my_rv_ref, + raise_error = False) + + """ + + flags = {"added": False, "skipped": False} + + # Find the source in the database, make sure there's only one match + db_name = find_source_in_db(db, source) + if len(db_name) != 1: + msg = f"No unique source match for {source} in the database" + flags["skipped"] = True + logger.error(msg) + if raise_error: + raise AstroDBError(msg) + else: + return flags + else: + db_name = db_name[0] + + # Make sure the publication is in the database + pub_check = find_publication(db, reference=reference) + if pub_check[0]: + msg = f"Reference found: {pub_check[1]}." + logger.info(msg) + if not pub_check[0]: + flags["skipped"] = True + msg = f"Reference {reference} not found in Publications table." + if raise_error: + logger.error(msg) + raise AstroDBError(msg) + else: + logger.warning(msg) + return flags + + # Search for existing radial velocity data and determine if this is the best + # If no previous measurement exists, set the new one to the Adopted measurement + adopted = None + source_rv_data: Table = ( + db.query(db.RadialVelocities) + .filter(db.RadialVelocities.c.source == db_name) + .table() + ) + + if source_rv_data is None or len(source_rv_data) == 0: + # if there's no other measurements in the database, set new data Adopted = True + adopted = True + logger.debug("No other measurement") + elif len(source_rv_data) > 0: # Radial Velocity data already exists + # check for duplicate measurement + dupe_ind = source_rv_data["reference"] == reference + if sum(dupe_ind): + msg = f"Duplicate radial velocity measurement\n, {source_rv_data[dupe_ind]}" + logger.warning(msg) + flags["skipped"] = True + if raise_error: + raise AstroDBError(msg) + else: + return flags + else: + msg = "!!! Another Radial Velocity measurement exists," + logger.warning(msg) + if logger.level == 10: + source_rv_data.pprint_all() + + # check for previous adopted measurement and find new adopted + adopted_ind = source_rv_data["adopted"] == 1 + if sum(adopted_ind): + old_adopted = source_rv_data[adopted_ind] + # if errors of new data are less than other measurements, + # set Adopted = True. + if rv_err < min(source_rv_data["radial_velocity_error"]): + adopted = True + + # unset old adopted + if old_adopted: + db.RadialVelocities.update().where( + and_( + db.RadialVelocities.c.source + == old_adopted["source"][0], + db.RadialVelocities.c.reference + == old_adopted["reference"][0], + ) + ).values(adopted=False).execute() + # check that adopted flag is successfully changed + old_adopted_data = ( + db.query(db.RadialVelocities) + .filter( + and_( + db.RadialVelocities.c.source + == old_adopted["source"][0], + db.RadialVelocities.c.reference + == old_adopted["reference"][0], + ) + ) + .table() + ) + logger.debug("Old adopted measurement unset") + if logger.level == 10: + old_adopted_data.pprint_all() + + logger.debug(f"The new measurement's adopted flag is:, {adopted}") + else: + msg = "Unexpected state" + logger.error(msg) + raise RuntimeError(msg) + + # Construct data to be added + radial_velocity_data = [ + { + "source": db_name, + "radial_velocity_km_s": rv, + "radial_velocity_error_km_s": rv_err, + "reference": reference, + "adopted": adopted, + } + ] + + logger.debug(f"{radial_velocity_data}") + + try: + with db.engine.connect() as conn: + conn.execute(db.RadialVelocities.insert().values(radial_velocity_data)) + conn.commit() + flags["added"] = True + msg = f"Radial Velocity added to database: \n {radial_velocity_data}" + logger.debug(msg) + except sqlalchemy.exc.IntegrityError: + flags["skipped"] = True + msg = ( + "The source may not exist in Sources table.\n" + "The Radial Velocity reference may not exist in Publications table. " + "Add it with add_publication function. \n" + "The radial velocity measurement may be a duplicate." + ) + logger.error(msg) + raise AstroDBError(msg) + + return flags diff --git a/simple/utils/companions.py b/simple/utils/companions.py index e727492e1..e2ed4ae82 100644 --- a/simple/utils/companions.py +++ b/simple/utils/companions.py @@ -150,7 +150,7 @@ def ingest_companion_relationships( else: msg = ( - "Make sure all require parameters are provided. \\" + "Make sure all required parameters are provided. \\" "Other possible errors: source may not exist in Sources table \\" "or the reference may not exist in the Publications table. " ) diff --git a/simple/utils/spectral_types.py b/simple/utils/spectral_types.py index 71df34e08..2dfb8944b 100644 --- a/simple/utils/spectral_types.py +++ b/simple/utils/spectral_types.py @@ -214,7 +214,7 @@ def ingest_spectral_types( .count() == 0 ): - msg = f"The publication {references[i]} does not exist in the database" + msg = f"The publication does not exist in the database: {references[i]}" msg1 = "Add it with ingest_publication function." logger.debug(msg + msg1) raise AstroDBError(msg) @@ -272,10 +272,10 @@ def convert_spt_string_to_code(spectral_types): else: # only trigger if not MLTY i = 0 # find integer or decimal subclass and add to spt_code - if re.search("\d*\.?\d+", spt[i + 1 :]) is None: + if re.search(r"\d*\.?\d+", spt[i + 1 :]) is None: spt_code = spt_code else: - spt_code += float(re.findall("\d*\.?\d+", spt[i + 1 :])[0]) + spt_code += float(re.findall(r"\d*\.?\d+", spt[i + 1 :])[0]) spectral_type_codes.append(spt_code) return spectral_type_codes diff --git a/tests/conftest.py b/tests/conftest.py index aca111072..7501c170f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ import pytest import os -import sys import logging from astrodbkit2.astrodb import create_database, Database -sys.path.append("./") +import sys + +sys.path.append("./") # needed for github actions to find the simple module from simple.schema import REFERENCE_TABLES from simple.schema import * diff --git a/tests/test_astrometry.py b/tests/test_astrometry.py new file mode 100644 index 000000000..ee0f15136 --- /dev/null +++ b/tests/test_astrometry.py @@ -0,0 +1,167 @@ +import pytest +from astropy.table import Table +from astrodb_utils import AstroDBError +from simple.utils.astrometry import ( + ingest_parallaxes, + ingest_proper_motions, + ingest_radial_velocity, +) + + +# Create fake astropy Table of data to load +@pytest.fixture(scope="module") +def t_plx(): + t_plx = Table( + [ + {"source": "Fake 1", "plx": 113.0, "plx_err": 0.3, "plx_ref": "Ref 1"}, + {"source": "Fake 2", "plx": 145.0, "plx_err": 0.5, "plx_ref": "Ref 1"}, + {"source": "Fake 3", "plx": 155.0, "plx_err": 0.6, "plx_ref": "Ref 2"}, + ] + ) + return t_plx + + +@pytest.fixture(scope="module") +def t_pm(): + t_pm = Table( + [ + { + "source": "Fake 1", + "mu_ra": 113.0, + "mu_ra_err": 0.3, + "mu_dec": 113.0, + "mu_dec_err": 0.3, + "reference": "Ref 1", + }, + { + "source": "Fake 2", + "mu_ra": 145.0, + "mu_ra_err": 0.5, + "mu_dec": 113.0, + "mu_dec_err": 0.3, + "reference": "Ref 1", + }, + { + "source": "Fake 3", + "mu_ra": 55.0, + "mu_ra_err": 0.23, + "mu_dec": 113.0, + "mu_dec_err": 0.3, + "reference": "Ref 2", + }, + ] + ) + return t_pm + + +@pytest.fixture(scope="module") +def t_rv(): + t_rv = Table( + [ + {"source": "Fake 1", "rv": 113.0, "rv_err": 0.3, "rv_ref": "Ref 1"}, + {"source": "Fake 2", "rv": 145.0, "rv_err": 0.5, "rv_ref": "Ref 1"}, + {"source": "Fake 3", "rv": "155.0", "rv_err": "0.6", "rv_ref": "Ref 2"}, + ] + ) + return t_rv + + +def test_ingest_parallaxes(temp_db, t_plx): + # Test ingest of parallax data + ingest_parallaxes( + temp_db, t_plx["source"], t_plx["plx"], t_plx["plx_err"], t_plx["plx_ref"] + ) + + results = ( + temp_db.query(temp_db.Parallaxes) + .filter(temp_db.Parallaxes.c.reference == "Ref 1") + .table() + ) + assert len(results) == 2 + results = ( + temp_db.query(temp_db.Parallaxes) + .filter(temp_db.Parallaxes.c.reference == "Ref 2") + .table() + ) + assert len(results) == 1 + assert results["source"][0] == "Fake 3" + assert results["parallax"][0] == 155 + assert results["parallax_error"][0] == 0.6 + + +def test_ingest_proper_motions(temp_db, t_pm): + ingest_proper_motions( + temp_db, + t_pm["source"], + t_pm["mu_ra"], + t_pm["mu_ra_err"], + t_pm["mu_dec"], + t_pm["mu_dec_err"], + t_pm["reference"], + ) + assert ( + temp_db.query(temp_db.ProperMotions) + .filter(temp_db.ProperMotions.c.reference == "Ref 1") + .count() + == 2 + ) + results = ( + temp_db.query(temp_db.ProperMotions) + .filter(temp_db.ProperMotions.c.reference == "Ref 2") + .table() + ) + assert len(results) == 1 + assert results["source"][0] == "Fake 3" + assert results["mu_ra"][0] == 55 + assert results["mu_ra_error"][0] == 0.23 + + +def test_ingest_radial_velocities_works(temp_db, t_rv): + for ind in range(3): + ingest_radial_velocity( + temp_db, + source=t_rv["source"][ind], + rv=t_rv["rv"][ind], + rv_err=t_rv["rv_err"][ind], + reference=t_rv["rv_ref"][ind], + ) + + results = ( + temp_db.query(temp_db.RadialVelocities) + .filter(temp_db.RadialVelocities.c.reference == "Ref 1") + .table() + ) + assert len(results) == 2 + results = ( + temp_db.query(temp_db.RadialVelocities) + .filter(temp_db.RadialVelocities.c.reference == "Ref 2") + .table() + ) + assert len(results) == 1 + assert results["source"][0] == "Fake 3" + assert results["radial_velocity_km_s"][0] == 155 + assert results["radial_velocity_error_km_s"][0] == 0.6 + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_ingest_radial_velocities_errors(temp_db): + with pytest.raises(AstroDBError) as error_message: + ingest_radial_velocity( + temp_db, source="not a source", rv=12.5, rv_err=0.5, reference="Ref 1" + ) + assert "No unique source match" in str(error_message.value) + # flag['skipped'] = True, flag['added'] = False + + with pytest.raises(AstroDBError) as error_message: + ingest_radial_velocity( + temp_db, source="Fake 1", rv=12.5, rv_err=0.5, reference="Ref 1" + ) + assert "Duplicate radial velocity measurement" in str(error_message.value) + # flag['skipped'] = True, flag['added'] = False + + with pytest.raises(AstroDBError) as error_message: + ingest_radial_velocity( + temp_db, source="Fake 1", rv=12.5, rv_err=0.5, reference="not a ref" + ) + assert "not found in Publications table" in str(error_message.value) + # flag['skipped'] = True, flag['added'] = False diff --git a/tests/test_integrity.py b/tests/test_integrity.py index de44d09bb..0d8e2e39a 100644 --- a/tests/test_integrity.py +++ b/tests/test_integrity.py @@ -378,7 +378,7 @@ def test_radialvelocities(db): # There should be no entries in the RadialVelocities table without rv values t = ( db.query(db.RadialVelocities.c.source) - .filter(db.RadialVelocities.c.radial_velocity.is_(None)) + .filter(db.RadialVelocities.c.radial_velocity_km_s.is_(None)) .astropy() ) if len(t) > 0: diff --git a/tests/test_utils.py b/tests/test_utils.py index 22c492954..87c8d8cd7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,17 +1,13 @@ import pytest -import sys from astropy.table import Table from astrodb_utils.utils import ( AstroDBError, ) - -sys.path.append("./") from simple.utils.spectral_types import ( convert_spt_string_to_code, ingest_spectral_types, ) from simple.utils.companions import ingest_companion_relationships -from simple.utils.astrometry import ingest_parallaxes, ingest_proper_motions # Create fake astropy Table of data to load @@ -68,56 +64,6 @@ def test_convert_spt_string_to_code(): assert convert_spt_string_to_code(["Y2pec"]) == [92] -def test_ingest_parallaxes(temp_db, t_plx): - # Test ingest of parallax data - ingest_parallaxes( - temp_db, t_plx["source"], t_plx["plx"], t_plx["plx_err"], t_plx["plx_ref"] - ) - - results = ( - temp_db.query(temp_db.Parallaxes) - .filter(temp_db.Parallaxes.c.reference == "Ref 1") - .table() - ) - assert len(results) == 2 - results = ( - temp_db.query(temp_db.Parallaxes) - .filter(temp_db.Parallaxes.c.reference == "Ref 2") - .table() - ) - assert len(results) == 1 - assert results["source"][0] == "Fake 3" - assert results["parallax"][0] == 155 - assert results["parallax_error"][0] == 0.6 - - -def test_ingest_proper_motions(temp_db, t_pm): - ingest_proper_motions( - temp_db, - t_pm["source"], - t_pm["mu_ra"], - t_pm["mu_ra_err"], - t_pm["mu_dec"], - t_pm["mu_dec_err"], - t_pm["reference"], - ) - assert ( - temp_db.query(temp_db.ProperMotions) - .filter(temp_db.ProperMotions.c.reference == "Ref 1") - .count() - == 2 - ) - results = ( - temp_db.query(temp_db.ProperMotions) - .filter(temp_db.ProperMotions.c.reference == "Ref 2") - .table() - ) - assert len(results) == 1 - assert results["source"][0] == "Fake 3" - assert results["mu_ra"][0] == 55 - assert results["mu_ra_error"][0] == 0.23 - - def test_ingest_spectral_types(temp_db): data1 = Table( [ @@ -142,14 +88,6 @@ def test_ingest_spectral_types(temp_db): ] ) - # data2 = Table( - # [ - # {"source": "Fake 1", "spectral_type": "M5.6", "reference": "Ref 1"}, - # {"source": "Fake 2", "spectral_type": "T0.1", "reference": "Ref 1"}, - # {"source": "Fake 3", "spectral_type": "Y2pec", "reference": "Ref 2"}, - # ] - # ) - data3 = Table( [ { @@ -200,12 +138,10 @@ def test_ingest_spectral_types(temp_db): temp_db, data3["source"], data3["spectral_type"], - data3["regime"], data3["reference"], + data3["regime"], ) - assert "The publication does not exist in the database" in str( - error_message.value - ) + assert "The publication does not exist in the database" in str(error_message.value) def test_companion_relationships(temp_db): @@ -213,7 +149,7 @@ def test_companion_relationships(temp_db): # trying no companion with pytest.raises(AstroDBError) as error_message: ingest_companion_relationships(temp_db, "Fake 1", None, "Sibling") - assert "Make sure all require parameters are provided." in str(error_message.value) + assert "Make sure all required parameters are provided." in str(error_message.value) # trying companion == source with pytest.raises(AstroDBError) as error_message: