Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use check_spectrum_plottable from astrodb_utils #548

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 3 additions & 80 deletions simple/utils/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@
import sqlite3
from typing import Optional

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import requests
import sqlalchemy.exc
from astrodb_utils import (
AstroDBError,
find_source_in_db,
internet_connection,
)
from astrodb_utils.spectra import check_spectrum_plottable
from astrodbkit2.astrodb import Database
from astropy.io import fits
from specutils import Spectrum1D

from simple.schema import Spectra

__all__ = [
"ingest_spectrum",
"ingest_spectrum_from_fits",
"spectrum_plottable",
"find_spectra",
]

Expand Down Expand Up @@ -162,7 +158,7 @@ def ingest_spectrum(
return flags

# Check if spectrum is plottable
flags["plottable"] = spectrum_plottable(spectrum, raise_error=raise_error)
flags["plottable"] = check_spectrum_plottable(spectrum, raise_error=raise_error)

# Compile fields into a dictionary
row_data = {
Expand All @@ -184,7 +180,7 @@ def ingest_spectrum(

try:
# Attempt to add spectrum to database
# This will throw errors based on validation in schema.py
# This will throw errors based on validation in schema.py
# and any database checks (as for example IntegrityError)
obj = Spectra(**row_data)
with db.session as session:
Expand Down Expand Up @@ -265,79 +261,6 @@ def ingest_spectrum_from_fits(db, source, spectrum_fits_file):
)


def spectrum_plottable(spectrum_path, raise_error=True, show_plot=False):
"""
Check if spectrum is plottable
"""
# load the spectrum and make sure it's a Spectrum1D object

try:
# spectrum: Spectrum1D = load_spectrum(spectrum_path) #astrodbkit2 method
spectrum = Spectrum1D.read(spectrum_path)
except Exception as e:
msg = (
str(e) + f"\nSkipping {spectrum_path}: \n"
"unable to load file as Spectrum1D object"
)
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

# checking spectrum has good units and not only NaNs
try:
wave: np.ndarray = spectrum.spectral_axis.to(u.micron).value
flux: np.ndarray = spectrum.flux.value
except AttributeError as e:
msg = str(e) + f"Skipping {spectrum_path}: unable to parse spectral axis"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False
except u.UnitConversionError as e:
msg = (
f"{e} \n"
f"Skipping {spectrum_path}: unable to convert spectral axis to microns"
)
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False
except ValueError as e:
msg = f"{e} \nSkipping {spectrum_path}: Value error"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

# check for NaNs
nan_check: np.ndarray = ~np.isnan(flux) & ~np.isnan(wave)
wave = wave[nan_check]
flux = flux[nan_check]
if not len(wave):
msg = f"Skipping {spectrum_path}: spectrum is all NaNs"
if raise_error:
logger.error(msg)
raise AstroDBError(msg)
else:
logger.warning(msg)
return False

if show_plot:
plt.plot(wave, flux)
plt.show()

return True


def find_spectra(
db: Database,
source: str,
Expand Down
217 changes: 94 additions & 123 deletions tests/test_spectra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,83 +7,111 @@
from simple.utils.spectra import (
ingest_spectrum,
# ingest_spectrum_from_fits,
spectrum_plottable,
)


@pytest.mark.filterwarnings(
"ignore", message=".*Note: astropy.io.fits uses zero-based indexing.*"
"ignore",
message=".*SAWarning: Column 'Spectra.reference' is marked as a member of the primary key for table 'Spectra'.*",
)
@pytest.mark.filterwarnings(
"ignore", message=".*'datfix' made the change 'Set MJD-OBS to.*"
"ignore", message=".*'kiwi': No known catalog could be found.*"
)
@pytest.mark.filterwarnings(
"ignore",
message=(
".*'erg/cm2/s/A' contains multiple slashes, "
"which is discouraged by the FITS standard.*",
),
@pytest.mark.parametrize(
"test_input, message",
[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a note, if you want to prevent auto-formatters from reformatting text (for example if you prefer something more compact even if it's not Black/ruff/darker-standard) you can preface the block with # fmt: off and then end it with # fmt: on

(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
},
"Value required for regime",
), # missing regime
(
{
"source": "apple",
"regime": "nir",
"instrument": "SpeX",
"obs_date": "2020-01-01",
},
"Value required for telescope",
), # missing telescope
(
{
"source": "apple",
"regime": "nir",
"telescope": "IRTF",
"obs_date": "2020-01-01",
},
"Value required for instrument",
), # missing instrument
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
},
"NOT NULL constraint failed: Spectra.reference",
), # missing reference
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 5",
},
"FOREIGN KEY constraint failed",
), # invalid reference
(
{
"source": "kiwi",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 1",
},
"No unique source match for kiwi in the database",
), # invalid source
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"reference": "Ref 1",
},
"Invalid date received: None",
), # missing date
(
{
"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "fake regime",
"obs_date": "2020-01-01",
"reference": "Ref 1",
},
"FOREIGN KEY constraint failed",
), # invalid regime
],
)
@pytest.mark.parametrize("test_input, message", [
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
}, "Value required for regime"), # missing regime
({"source": "apple",
"regime": "nir",
"instrument": "SpeX",
"obs_date": "2020-01-01",
}, "Value required for telescope"), # missing telescope
({"source": "apple",
"regime": "nir",
"telescope": "IRTF",
"obs_date": "2020-01-01",
}, "Value required for instrument"), # missing instrument
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
}, "NOT NULL constraint failed: Spectra.reference"), # missing reference
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 5",
}, "FOREIGN KEY constraint failed"), # invalid reference
({"source": "kiwi",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"obs_date": "2020-01-01",
"reference": "Ref 1",
}, "No unique source match for kiwi in the database"), # invalid source
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "nir",
"reference": "Ref 1",
}, "Invalid date received: None"), # missing date
({"source": "apple",
"telescope": "IRTF",
"instrument": "SpeX",
"mode": "Prism",
"regime": "fake regime",
"obs_date": "2020-01-01",
"reference": "Ref 1",
}, "FOREIGN KEY constraint failed"), # invalid regime
])
def test_ingest_spectrum_errors(temp_db, test_input, message):
# Test for ingest_spectrum that is expected to return errors

# Prepare parameters to send to ingest_spectrum
spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits"
spectrum = "https://bdnyc.s3.amazonaws.com/IRS/2MASS+J03552337%2B1133437.fits"
parameters = {"db": temp_db, "spectrum": spectrum}
parameters.update(test_input)

Expand All @@ -98,19 +126,8 @@ def test_ingest_spectrum_errors(temp_db, test_input, message):
assert message in result["message"]


@pytest.mark.filterwarnings("ignore:Verification")
@pytest.mark.filterwarnings("ignore", message=".*Card 'AIRMASS' is not FITS standard.*")
@pytest.mark.filterwarnings(
"ignore:Note"
) # : astropy.io.fits uses zero-based indexing.
@pytest.mark.filterwarnings("ignore:'datfix' made the change 'Set MJD-OBS to")
@pytest.mark.filterwarnings(
"ignore:'erg/cm2/s/A' contains multiple slashes,"
" which is discouraged by the FITS standard"
)
@pytest.mark.filterwarnings("ignore")
def test_ingest_spectrum_works(temp_db):
spectrum = "https://bdnyc.s3.amazonaws.com/tests/U10176.fits"
spectrum = "https://bdnyc.s3.amazonaws.com/IRS/2MASS+J03552337%2B1133437.fits"
result = ingest_spectrum(
temp_db,
source="banana",
Expand All @@ -123,49 +140,3 @@ def test_ingest_spectrum_works(temp_db):
mode="Prism",
)
assert result["added"] is True


@pytest.mark.filterwarnings("ignore:Invalid 'BLANK' keyword in header.")
@pytest.mark.filterwarnings("ignore:'datfix' made the change 'Set MJD-OBS to")
@pytest.mark.filterwarnings("ignore:The WCS transformation has more axes")
@pytest.mark.filterwarnings("ignore:'cdfix' made the change 'Success'")
@pytest.mark.filterwarnings("ignore:MJD-OBS =")
@pytest.mark.filterwarnings(
"ignore",
message=(
"'erg/cm2/s/A' contains multiple slashes, "
"which is discouraged by the FITS standard.*",
),
)
@pytest.mark.filterwarnings("ignore")
@pytest.mark.parametrize(
"file",
[
"https://s3.amazonaws.com/bdnyc/optical_spectra/2MASS1538-1953_tell.fits",
"https://s3.amazonaws.com/bdnyc/spex_prism_lhs3003_080729.txt",
"https://bdnyc.s3.amazonaws.com/IRS/2351-2537_IRS_spectrum.dat",
],
)
def test_spectrum_plottable_false(file):
with pytest.raises(AstroDBError) as error_message:
spectrum_plottable(file)
assert "unable to load file as Spectrum1D object" in str(error_message.value)

result = spectrum_plottable(file, raise_error=False)
assert result is False


@pytest.mark.parametrize(
"file",
[
(
"https://bdnyc.s3.amazonaws.com/SpeX/Prism/"
"2MASS+J04510093-3402150_2012-09-27.fits"
),
"https://bdnyc.s3.amazonaws.com/IRS/2MASS+J23515044-2537367.fits",
"https://bdnyc.s3.amazonaws.com/optical_spectra/vhs1256b_opt_Osiris.fits",
],
)
def test_spectrum_plottable_true(file):
result = spectrum_plottable(file)
assert result is True