Skip to content

Commit

Permalink
Merge pull request #79 from mdecleir/refine_plotting
Browse files Browse the repository at this point in the history
Refine plotting
  • Loading branch information
karllark authored Mar 23, 2021
2 parents 5be6588 + 7c20f04 commit c70565b
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 85 deletions.
Binary file removed measure_extinction/data/hd229238_hd204172_ext.fits
Binary file not shown.
73 changes: 35 additions & 38 deletions measure_extinction/extdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from astropy.io import fits
from scipy.optimize import curve_fit

from dust_extinction.parameter_averages import F04
from astropy.modeling.powerlaws import PowerLaw1D
from astropy.modeling import Parameter
from astropy.modeling.fitting import LevMarLSQFitter
Expand Down Expand Up @@ -98,7 +97,7 @@ def _get_column_val(column):
return float(column)


def AverageExtData(extdatas):
def AverageExtData(extdatas, min_number=3):
"""
Generate the average extinction curve from a list of ExtData objects
Expand All @@ -107,6 +106,9 @@ def AverageExtData(extdatas):
extdatas : list of ExtData objects
list of extinction curves to average
min_number : int [default=3]
minimum number of extinction curves that are required to measure the average extinction; if less than min_number of curves are available at certain wavelengths, the average extinction will still be calculated, but the number of points (npts) at those wavelengths will be set to zero (e.g. used in the plotting)
Returns
-------
aveext: ExtData object
Expand All @@ -117,7 +119,8 @@ def AverageExtData(extdatas):
names = []
bwaves = []
for extdata in extdatas:
# check the data type of the extinction curves, and convert if needed
# check the data type of the extinction curve, and convert if needed
# the average curve must be calculated from the A(lambda)/A(V) curves
if extdata.type != "alav" or extdata.type != "alax":
extdata.trans_elv_alav()

Expand All @@ -136,7 +139,7 @@ def AverageExtData(extdatas):
aveext.type = extdatas[0].type
aveext.type_rel_band = extdatas[0].type_rel_band

# calculate the average for all spectral data
# collect all the extinction data
bexts = {k: [] for k in aveext.names["BAND"]}
for src in keys:
exts = []
Expand All @@ -149,27 +152,39 @@ def AverageExtData(extdatas):
extdata.exts[src][np.where(extdata.npts[src] == 0)] = np.nan
exts.append(extdata.exts[src])

# calculate the average and uncertainties of the band extinction data
if src == "BAND":
aveext.exts["BAND"] = []
aveext.npts["BAND"] = []
aveext.stds["BAND"] = []
aveext.uncs["BAND"] = []
for name in aveext.names["BAND"]:
aveext.exts["BAND"].append(np.nanmean(bexts[name]))
aveext.npts["BAND"].append(len(bexts[name]))
aveext.exts["BAND"] = np.zeros(len(names))
aveext.npts["BAND"] = np.zeros(len(names))
aveext.stds["BAND"] = np.zeros(len(names))
aveext.uncs["BAND"] = np.zeros(len(names))
for i, name in enumerate(aveext.names["BAND"]):
aveext.exts["BAND"][i] = np.nanmean(bexts[name])
aveext.npts["BAND"][i] = len(bexts[name])

# calculation of the standard deviation (this is the spread of the sample around the population mean)
aveext.stds["BAND"].append(np.nanstd(bexts[name], ddof=1))
aveext.stds["BAND"][i] = np.nanstd(bexts[name], ddof=1)

# calculation of the standard error of the average (the standard error of the sample mean is an estimate of how far the sample mean is likely to be from the population mean)
aveext.uncs["BAND"] = aveext.stds["BAND"] / np.sqrt(aveext.npts["BAND"])

# calculate the average and uncertainties of the spectral extinction data
else:
aveext.exts[src] = np.nanmean(exts, axis=0)
aveext.npts[src] = np.sum(~np.isnan(exts), axis=0)
aveext.stds[src] = np.nanstd(exts, axis=0, ddof=1)
aveext.uncs[src] = aveext.stds[src] / np.sqrt(aveext.npts[src])

# take out the data points where less than a certain number of values was averaged, and give a warning
if min_number > 1:
aveext.npts[src][aveext.npts[src] < min_number] = 0
warnings.warn(
"The minimum number of "
+ str(min_number)
+ " extinction curves was not reached for certain wavelengths, and the number of points (npts) for those wavelengths was set to 0.",
UserWarning,
)

return aveext


Expand Down Expand Up @@ -490,14 +505,12 @@ def calc_RV(self):

def trans_elv_alav(self, av=None, akav=0.112):
"""
Transform E(lambda-V) to A(lambda)/A(V) by normalizing to
A(V) and adding 1. Default is to calculate A(V) from the
input elx curve. If A(V) value is passed, use that one instead.
Transform E(lambda-V) to A(lambda)/A(V) by normalizing to A(V) and adding 1. If A(V) is in the columns of the extdata object, use that value. If A(V) is passed explicitly, use that value instead. If no A(V) is available, calculate A(V) from the input elx curve.
Parameters
----------
av : float [default = None]
value of A(V) to use - otherwise calculate it
value of A(V) to use - otherwise take it from the columns of the object or calculate it
akav : float [default = 0.112]
Value of A(K)/A(V), only needed if A(V) has to be calculated from the K-band extinction
Expand All @@ -514,10 +527,11 @@ def trans_elv_alav(self, av=None, akav=0.112):
)
else:
if av is None:
self.calc_AV(akav=akav)
if "AV" not in self.columns.keys():
self.calc_AV(akav=akav)
av = _get_column_val(self.columns["AV"])

for curname in self.exts.keys():
av = _get_column_val(self.columns["AV"])
self.exts[curname] = (self.exts[curname] / av) + 1
self.uncs[curname] /= av
# update the extinction curve type
Expand Down Expand Up @@ -1091,19 +1105,8 @@ def plot(
fontsize for plot
"""
if alax:
# compute A(V) if it is not available
if "AV" not in self.columns.keys():
self.trans_elv_alav()
av = _get_column_val(self.columns["AV"])
if self.type_rel_band != "V": # not sure if this works (where is RV given?)
# use F04 model to convert AV to AX
rv = _get_column_val(self.columns["RV"])
emod = F04(rv)
(indx,) = np.where(self.type_rel_band == self.names["BAND"])
axav = emod(self.waves["BAND"][indx[0]])
else:
axav = 1.0
ax = axav * av
# transform the extinctions from E(lambda-V) to A(lambda)/A(V)
self.trans_elv_alav()

for curtype in self.waves.keys():
# do not plot the excluded data type(s)
Expand All @@ -1115,13 +1118,6 @@ def plot(
y = self.exts[curtype]
yu = self.uncs[curtype]

if (
alax and self.type == "elx"
): # in the case A(V) was already available and the curve has not been transformed yet
# convert E(lambda-X) to A(lambda)/A(X)
y = (y / ax) + 1.0
yu /= ax

y = y / normval + yoffset
yu = yu / normval

Expand Down Expand Up @@ -1160,6 +1156,7 @@ def plot(
color=annotate_color,
horizontalalignment="left",
rotation=annotate_rotation,
rotation_mode="anchor",
fontsize=fontsize,
)

Expand Down
83 changes: 61 additions & 22 deletions measure_extinction/plotting/plot_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@
import numpy as np
import astropy.units as u
import pandas as pd
import os

from measure_extinction.extdata import ExtData
from measure_extinction.utils.calc_ext import calc_ave_ext
from dust_extinction.parameter_averages import CCM89


def plot_average(
starpair_list,
path,
filename="average_ext.fits",
ax=None,
extmodels=False,
fitmodel=False,
HI_lines=False,
range=None,
exclude=[],
log=False,
spread=False,
annotate_key=None,
annotate_wave_range=None,
Expand All @@ -34,11 +35,11 @@ def plot_average(
Parameters
----------
starpair_list : list of strings
List of star pairs for which to calculate and plot the average extinction curve, in the format "reddenedstarname_comparisonstarname" (no spaces)
path : string
Path to the data files
Path to the average extinction curve fits file
filename : string [default="average_ext.fits"]
Name of the average extinction curve fits file
ax : AxesSubplot [default=None]
Axes of plot on which to add the average extinction curve if pdf=False
Expand All @@ -58,14 +59,17 @@ def plot_average(
exclude : list of strings [default=[]]
List of data type(s) to exclude from the plot (e.g., IRS)
log : boolean [default=False]
Whether or not to plot the wavelengths on a log-scale
spread : boolean [default=False]
Whether or not to offset the average extinction curve from the other curves
Whether or not to offset the average extinction curve from the other curves (only relevant when pdf=False and ax=None)
annotate_key : string [default=None]
type of data for which to annotate text (e.g., SpeX_LXD)
type of data for which to annotate text (e.g., SpeX_LXD) (only relevant when pdf=False and ax=None)
annotate_wave_range : list of 2 floats [default=None]
min/max wavelength range for the annotation of the text
min/max wavelength range for the annotation of the text (only relevant when pdf=False and ax=None)
pdf : boolean [default=False]
- If False, the average extinction curve will be overplotted on the current plot (defined by ax)
Expand All @@ -75,11 +79,18 @@ def plot_average(
-------
Plots the average extinction curve
"""
# calculate the average extinction curve
calc_ave_ext(starpair_list, path)

# read in the average extinction curve
average = ExtData(path + "average_ext.fits")
# read in the average extinction curve (if it exists)
if os.path.isfile(path + filename):
average = ExtData(path + filename)
else:
warnings.warn(
"An average extinction curve with the name "
+ filename
+ " could not be found in "
+ path
+ ". Please calculate the average extinction curve first with the calc_ave_ext function in measure_extinction/utils/calc_ext.py.",
UserWarning,
)

# make a new plot if requested
if pdf:
Expand Down Expand Up @@ -116,14 +127,17 @@ def plot_average(
zoom(ax, range)

# finish configuring the plot
ax.set_title("average", fontsize=50)
ax.set_xscale("log")
if log:
ax.set_xscale("log")
plt.xlabel(r"$\lambda$ [$\mu m$]", fontsize=1.5 * fontsize)
ax.set_ylabel(
average._get_ext_ytitle(ytype=average.type), fontsize=1.5 * fontsize
)
fig.savefig(path + "average_ext.pdf", bbox_inches="tight")

# return the figure and axes for additional manipulations
return fig, ax

else:
if spread:
yoffset = -0.3
Expand Down Expand Up @@ -258,6 +272,7 @@ def plot_fitmodel(extdata, yoffset=0, res=False):
plt.axhline(ls="--", c="k", alpha=0.5)
plt.axhline(y=0.05, ls=":", c="k", alpha=0.5)
plt.axhline(y=-0.05, ls=":", c="k", alpha=0.5)
plt.ylim(-0.1, 0.1)
plt.ylabel("residual")

else:
Expand Down Expand Up @@ -359,6 +374,9 @@ def plot_multi_extinction(
range=None,
spread=False,
exclude=[],
log=False,
text_offsets=[],
text_angles=[],
pdf=False,
):
"""
Expand Down Expand Up @@ -396,6 +414,15 @@ def plot_multi_extinction(
exclude : list of strings [default=[]]
List of data type(s) to exclude from the plot (e.g., IRS)
log : boolean [default=False]
Whether or not to plot the wavelengths on a log-scale
text_offsets : list of floats [default=[]]
List of the same length as starpair_list with offsets for the annotated text
text_angles : list of integers [default=[]]
List of the same length as starpair_list with rotation angles for the annotated text
pdf : boolean [default=False]
Whether or not to save the figure as a pdf file
Expand All @@ -419,13 +446,19 @@ def plot_multi_extinction(
fig, ax = plt.subplots(figsize=(15, len(starpair_list) * 1.25))
colors = plt.get_cmap("tab10")

# set default text offsets and angles
if text_offsets == []:
text_offsets = np.full(len(starpair_list), 0.2)
if text_angles == []:
text_angles = np.full(len(starpair_list), 10)

for i, starpair in enumerate(starpair_list):
# read in the extinction curve data
extdata = ExtData("%s%s_ext.fits" % (path, starpair.lower()))

# spread out the curves if requested
if spread:
yoffset = 0.3 * i
yoffset = 0.25 * i
else:
yoffset = 0.0

Expand Down Expand Up @@ -455,8 +488,8 @@ def plot_multi_extinction(
annotate_key=ann_key,
annotate_wave_range=ann_range,
annotate_text=extdata.red_file.split(".")[0].upper(),
annotate_yoffset=-0.1,
annotate_rotation=-15,
annotate_yoffset=text_offsets[i],
annotate_rotation=text_angles[i],
annotate_color=colors(i % 10),
)

Expand All @@ -477,7 +510,6 @@ def plot_multi_extinction(
# plot the average extinction curve if requested
if average:
plot_average(
starpair_list,
path,
ax=ax,
extmodels=extmodels,
Expand All @@ -501,16 +533,23 @@ def plot_multi_extinction(
outname = outname.replace(".pdf", "_zoom.pdf")

# finish configuring the plot
ax.set_xscale("log")
if log:
ax.set_xscale("log")
ax.set_xlabel(r"$\lambda$ [$\mu m$]", fontsize=1.5 * fontsize)
ax.set_ylabel(extdata._get_ext_ytitle(ytype=extdata.type), fontsize=1.5 * fontsize)
ylabel = extdata._get_ext_ytitle(ytype=extdata.type)
if spread:
ylabel += " + offset"
ax.set_ylabel(ylabel, fontsize=1.5 * fontsize)

# show the figure or save it to a pdf file
if pdf:
fig.savefig(path + outname, bbox_inches="tight")
else:
plt.show()

# return the figure and axes for additional manipulations
return fig, ax


def plot_extinction(
starpair,
Expand Down
Loading

0 comments on commit c70565b

Please sign in to comment.