Skip to content

Commit

Permalink
Fix type errors for matplotlib 3.9.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Jul 4, 2024
1 parent 833d3d9 commit c5305ae
Show file tree
Hide file tree
Showing 17 changed files with 96 additions and 31 deletions.
4 changes: 3 additions & 1 deletion artistools/estimators/plotestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,9 @@ def make_plot(
# tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.0},
)
if len(plotlist) == 1:
axes = [axes]
axes = np.array([axes])

assert isinstance(axes, np.ndarray)

# ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=5))
if not args.hidexlabel:
Expand Down
5 changes: 3 additions & 2 deletions artistools/gsinetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def plot_qdot(
tight_layout={"pad": 0.4, "w_pad": 0.0, "h_pad": 0.0},
)
if nrows == 1:
axes = [axes]
axes = np.array([axes])

assert isinstance(axes, np.ndarray)
axis = axes[0]

# axis.set_ylim(bottom=1e7, top=2e10)
Expand Down Expand Up @@ -235,7 +236,7 @@ def plot_cell_abund_evolution(
)
fig.subplots_adjust(top=0.8)
# axis.set_xscale('log')

assert isinstance(axes, np.ndarray)
axes[-1].set_xlabel("Time [days]")
axis = axes[0]
print("nuc gsi_abund artis_abund")
Expand Down
4 changes: 3 additions & 1 deletion artistools/inputmodel/downscale3dgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def make_downscaled_3d_grid(
print("matplotlib not found, skipping")
return outputfolder

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(6.8 * 1.5, 4.8))
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.8 * 1.5, 4.8))
assert isinstance(axes, np.ndarray)
(ax1, ax2) = axes

middle_ind = int(rho.shape[0] / 2)
im1 = ax1.imshow(rho[middle_ind, :, :])
Expand Down
2 changes: 2 additions & 0 deletions artistools/inputmodel/plotdensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argcomplete
import matplotlib.pyplot as plt
import numpy as np
import polars as pl

import artistools as at
Expand Down Expand Up @@ -40,6 +41,7 @@ def main(args: argparse.Namespace | None = None, argsraw: t.Sequence[str] | None
figsize=(8, 8),
tight_layout={"pad": 0.4, "w_pad": 0.0, "h_pad": 0.0},
)
assert isinstance(axes, np.ndarray)

if not args.modelpath:
args.modelpath = ["."]
Expand Down
5 changes: 2 additions & 3 deletions artistools/inputmodel/recombinationenergy.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,8 @@ def get_particles_recomb_nuc_energy(traj_root, dfbinding):
marker=".",
)
ax.legend(loc="best", handlelength=2, frameon=False, numpoints=1)
ax.xlabel("Ye")
# ax.ylabel('eV / g')
ax.yscale("log")
ax.set_xlabel("Ye")
ax.set_yscale("log")
ax.set_ylim(bottom=1e24, top=1e33)

fig.savefig("recomb.pdf", format="pdf")
Expand Down
46 changes: 35 additions & 11 deletions artistools/lightcurve/plotlightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import sys
import typing as t
from collections.abc import Iterable
from collections.abc import Sequence
from pathlib import Path

import argcomplete
import matplotlib as mpl
import matplotlib.axes as mplax
import matplotlib.cm as mplcm
import matplotlib.colors as mplcolors
import matplotlib.pyplot as plt
Expand All @@ -24,7 +26,14 @@
color_list = list(plt.get_cmap("tab20")(np.linspace(0, 1.0, 20)))


def plot_deposition_thermalisation(axis, axistherm, modelpath, modelname, plotkwargs, args: argparse.Namespace) -> None:
def plot_deposition_thermalisation(
axis: mplax.Axes,
axistherm: mplax.Axes | None,
modelpath: str | Path,
modelname: str,
plotkwargs: dict[str, t.Any],
args: argparse.Namespace,
) -> None:
# if args.logscalex:
# axistherm.set_xscale("log")

Expand Down Expand Up @@ -65,8 +74,8 @@ def plot_deposition_thermalisation(axis, axistherm, modelpath, modelname, plotkw
# 'color': color_total,
# }))

color_gamma = axis._get_lines.get_next_color() # noqa: SLF001
color_gamma = axis._get_lines.get_next_color() # noqa: SLF001
color_gamma = axis._get_lines.get_next_color() # type: ignore[attr-defined] # noqa: SLF001
color_gamma = axis._get_lines.get_next_color() # type: ignore[attr-defined] # noqa: SLF001

# axis.plot(depdata['tmid_days'], depdata['eps_gamma_Lsun'] * 3.826e33, **dict(
# plotkwargs, **{
Expand All @@ -88,7 +97,7 @@ def plot_deposition_thermalisation(axis, axistherm, modelpath, modelname, plotkw
},
)

color_beta = axis._get_lines.get_next_color() # noqa: SLF001
color_beta = axis._get_lines.get_next_color() # type: ignore[attr-defined] # noqa: SLF001

if "eps_elec_Lsun" in depdata:
axis.plot(
Expand Down Expand Up @@ -174,6 +183,7 @@ def plot_deposition_thermalisation(axis, axistherm, modelpath, modelname, plotkw
)

if args.plotthermalisation:
assert axistherm is not None
f_gamma = depdata["gammadep_Lsun"] / depdata["eps_gamma_Lsun"]
axistherm.plot(
depdata["tmid_days"],
Expand Down Expand Up @@ -276,7 +286,7 @@ def plot_artis_lightcurve(
frompackets: bool = False,
maxpacketfiles: int | None = None,
axistherm=None,
directionbins: t.Sequence[int] | None = None,
directionbins: Sequence[int] | None = None,
average_over_phi: bool = False,
average_over_theta: bool = False,
usedegrees: bool = False,
Expand Down Expand Up @@ -475,7 +485,7 @@ def plot_artis_lightcurve(


def make_lightcurve_plot(
modelpaths: t.Sequence[str | Path],
modelpaths: Sequence[str | Path],
filenameout: str | Path,
frompackets: bool = False,
escape_type: t.Literal["TYPE_RPKT", "TYPE_GAMMA"] = "TYPE_RPKT",
Expand Down Expand Up @@ -693,7 +703,9 @@ def create_axes(args):
figsize=(args.figwidth, args.figheight),
tight_layout={"pad": 3.0, "w_pad": 0.6, "h_pad": 0.6},
) # (6.2 * 3, 9.4 * 3)

if args.subplots:
assert isinstance(ax, np.ndarray)
ax = ax.flatten()

return fig, ax
Expand Down Expand Up @@ -829,7 +841,9 @@ def make_colorbar_viewingangles(phi_viewing_angle_bins, scaledmap, args, fig=Non
cbar.update_ticks()


def make_band_lightcurves_plot(modelpaths, filternames_conversion_dict, outputfolder, args: argparse.Namespace) -> None:
def make_band_lightcurves_plot(
modelpaths: Sequence[str | Path], filternames_conversion_dict: dict, outputfolder, args: argparse.Namespace
) -> None:
# angle_names = [0, 45, 90, 180]
# plt.style.use('dark_background')

Expand Down Expand Up @@ -916,11 +930,12 @@ def make_band_lightcurves_plot(modelpaths, filternames_conversion_dict, outputfo
if len(angles) > 1 and index > 0:
print("already plotted reflightcurve")
else:
assert isinstance(ax, mplax.Axes)
define_colours_list = args.refspeccolors
markers = args.refspecmarkers
for i, reflightcurve in enumerate(args.reflightcurves):
plot_lightcurve_from_refdata(
band_lightcurve_data.keys(),
list(band_lightcurve_data.keys()),
reflightcurve,
define_colours_list[i],
markers[i],
Expand Down Expand Up @@ -1131,13 +1146,20 @@ def colour_evolution_plot(modelpaths, filternames_conversion_dict, outputfolder,


def plot_lightcurve_from_refdata(
filter_names, lightcurvefilename, color, marker, filternames_conversion_dict, ax, plotnumber
):
filter_names: Sequence[str],
lightcurvefilename: Path | str,
color,
marker,
filternames_conversion_dict,
ax: np.ndarray | mplax.Axes,
plotnumber: int,
) -> str | None:
from extinction import apply
from extinction import ccm89

lightcurve_data, metadata = at.lightcurve.read_reflightcurve_band_data(lightcurvefilename)
linename = metadata["label"] if plotnumber == 0 else None
assert linename is None or isinstance(linename, str)
filterdir = Path(at.get_config()["path_artistools_dir"], "data/filters/")

filter_data = {}
Expand Down Expand Up @@ -1177,6 +1199,7 @@ def plot_lightcurve_from_refdata(
else:
print("WARNING: did not correct for reddening")
if len(filter_names) > 1:
assert isinstance(ax, np.ndarray)
ax[axnumber].plot(
filter_data[filter_name]["time"],
filter_data[filter_name]["magnitude"],
Expand All @@ -1185,6 +1208,7 @@ def plot_lightcurve_from_refdata(
color=color,
)
else:
assert isinstance(ax, mplax.Axes)
ax.plot(
filter_data[filter_name]["time"],
filter_data[filter_name]["magnitude"],
Expand Down Expand Up @@ -1566,7 +1590,7 @@ def addargs(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--legendframeon", action="store_true", help="Frame on in legend")


def main(args: argparse.Namespace | None = None, argsraw: t.Sequence[str] | None = None, **kwargs: t.Any) -> None:
def main(args: argparse.Namespace | None = None, argsraw: Sequence[str] | None = None, **kwargs: t.Any) -> None:
"""Plot ARTIS light curve."""
if args is None:
parser = argparse.ArgumentParser(formatter_class=at.CustomArgHelpFormatter, description=__doc__)
Expand Down
4 changes: 3 additions & 1 deletion artistools/linefluxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,10 @@ def make_flux_ratio_plot(args: argparse.Namespace) -> None:
),
tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.0},
)
assert isinstance(axes, np.ndarray)

if nrows == 1:
axes = [axes]
axes = np.array([axes])

axis = axes[0]
axis.set_yscale("log")
Expand Down Expand Up @@ -690,6 +691,7 @@ def make_emitting_regions_plot(args: argparse.Namespace) -> None:
),
tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.2},
)
assert isinstance(axis, mplax.Axes)

for refdataindex in range(len(refdatafilenames)):
timeindex = np.abs(refdatatimes[refdataindex] - tmid).argmin()
Expand Down
6 changes: 5 additions & 1 deletion artistools/nltepops/plotnltepops.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def make_plot_populations_with_time_or_velocity(modelpaths: list[Path | str], ar
figsize=(at.get_config()["figwidth"] * 2 * cols, at.get_config()["figwidth"] * 0.85 * rows),
tight_layout={"pad": 2.0, "w_pad": 0.2, "h_pad": 0.2},
)
assert isinstance(ax, np.ndarray)
if args.subplots:
ax = ax.flatten()

Expand All @@ -439,6 +440,7 @@ def make_plot_populations_with_time_or_velocity(modelpaths: list[Path | str], ar
axis.text(xmax * 0.85, ymin * 50, f"{args.timedayslist[plotnumber]} days")
ax[0].legend(loc="best", frameon=True, fontsize="x-small", ncol=1)
else:
assert isinstance(ax, mplax.Axes)
ax.legend(loc="best", frameon=True, fontsize="x-small", ncol=1)
ax.set_yscale("log")

Expand Down Expand Up @@ -562,7 +564,9 @@ def make_plot(modelpath, atomic_number, ion_stages_displayed, mgilist, timestep,
)

if nrows == 1:
axes = [axes]
axes = np.array([axes])

assert isinstance(axes, np.ndarray)

prev_ion_stage = -1
assert len(mgilist) > 0
Expand Down
2 changes: 2 additions & 0 deletions artistools/nonthermal/leptontransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math

import matplotlib.pyplot as plt
import numpy as np

CONST_EV_IN_J = 1.602176634e-19 # 1 eV [J]

Expand Down Expand Up @@ -135,6 +136,7 @@ def main() -> None:
fig, axes = plt.subplots(
nrows=2, ncols=1, sharex=False, figsize=(5, 8), tight_layout={"pad": 0.5, "w_pad": 0.0, "h_pad": 1.0}
)
assert isinstance(axes, np.ndarray)
axes[0].plot(arr_dist, arr_energy_ev)
axes[0].set_xlabel(r"Distance [m]")
axes[0].set_ylabel(r"Energy [eV]")
Expand Down
4 changes: 3 additions & 1 deletion artistools/nonthermal/plotnonthermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def make_plot(modelpaths: list[Path], args: argparse.Namespace) -> None:
)

if nplots == 1:
axes = [axes]
axes = np.array([axes])

assert isinstance(axes, np.ndarray)

if args.kf1992spec:
kf92spec = pd.read_csv(Path(modelpaths[0], "KF1992spec-fig1.txt"), header=None, names=["e_kev", "log10_y"])
Expand Down
4 changes: 3 additions & 1 deletion artistools/plotspherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def plot_spherical(
)

if len(plotvars) == 1:
axes = (axes,)
axes = np.array([axes])

assert isinstance(axes, np.ndarray)

# for ax, axcbar, plotvar in zip(axes[::2], axes[1::2], plotvars):
for ax, plotvar in zip(axes, plotvars, strict=False):
Expand Down
14 changes: 11 additions & 3 deletions artistools/plottools.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,25 @@ def set_axis_properties(ax: t.Iterable[mplax.Axes] | mplax.Axes, args: argparse.


def set_axis_labels(
fig: mplfig.Figure, ax: mplax.Axes, xlabel: str, ylabel: str, labelfontsize: int | None, args: argparse.Namespace
):
fig: mplfig.Figure,
ax: mplax.Axes | np.ndarray,
xlabel: str,
ylabel: str,
labelfontsize: int | None,
args: argparse.Namespace,
) -> None:
if args.subplots:
fig.text(0.5, 0.02, xlabel, ha="center", va="center")
fig.text(0.02, 0.5, ylabel, ha="center", va="center", rotation="vertical")
else:
assert isinstance(ax, mplax.Axes)
ax.set_xlabel(xlabel, fontsize=labelfontsize)
ax.set_ylabel(ylabel, fontsize=labelfontsize)


def imshow_init_for_artis_grid(ngrid: int, vmax: float, plot_variable_3d_array: npt.NDArray, plot_axes: str = "xy"):
def imshow_init_for_artis_grid(
ngrid: int, vmax: float, plot_variable_3d_array: npt.NDArray, plot_axes: str = "xy"
) -> tuple[npt.NDArray, tuple[float, float, float, float]]:
# ngrid = round(len(model['inputcellid']) ** (1./3.))
extentdict = {"left": -vmax, "right": vmax, "bottom": vmax, "top": -vmax}
extent = extentdict["left"], extentdict["right"], extentdict["bottom"], extentdict["top"]
Expand Down
12 changes: 11 additions & 1 deletion artistools/radfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,12 @@ def plot_celltimestep(modelpath, timestep, outputfile, xmin, xmax, modelgridinde
tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.0},
)

axis = axes if nrows == 1 else axes[-1]
if isinstance(axes, mplax.Axes):
axes = np.array([axes])
assert isinstance(axes, np.ndarray)
axis = axes[-1]

assert isinstance(axis, mplax.Axes)
ymax = 0.0

xlist, yvalues = get_fullspecfittedfield(radfielddata, xmin, xmax, modelgridindex=modelgridindex, timestep=timestep)
Expand Down Expand Up @@ -895,6 +900,11 @@ def plot_timeevolution(modelpath, outputfile, modelgridindex, args: argparse.Nam
tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.0},
)

if isinstance(axes, mplax.Axes):
axes = np.array([axes])

assert isinstance(axes, np.ndarray)

timestep = at.get_timestep_of_timedays(modelpath, 330)
time_days = at.get_timestep_time(modelpath, timestep)

Expand Down
7 changes: 4 additions & 3 deletions artistools/spectra/plotspectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def plot_artis_spectrum(

def make_spectrum_plot(
speclist: t.Collection[Path | str],
axes: t.Sequence[mplax.Axes],
axes: t.Sequence[mplax.Axes] | np.ndarray,
filterfunc: t.Callable[[npt.NDArray[np.floating] | pl.Series], npt.NDArray[np.floating]] | None,
args,
scale_to_peak: float | None = None,
Expand Down Expand Up @@ -948,7 +948,7 @@ def make_contrib_plot(
# ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='gouraud', cmap=plt.cm.BuGn_r)


def make_plot(args) -> tuple[mplfig.Figure, list[mplax.Axes], pl.DataFrame]:
def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
# font = {'size': 16}
# mpl.rc('font', **font)

Expand All @@ -975,7 +975,8 @@ def make_plot(args) -> tuple[mplfig.Figure, list[mplax.Axes], pl.DataFrame]:
tight_layout={"pad": 0.2, "w_pad": 0.0, "h_pad": 0.0},
)

axes = [axes] if nrows == 1 else list(axes)
axes = np.array([axes]) if nrows == 1 else np.array(axes)
assert isinstance(axes, np.ndarray)

filterfunc = at.get_filterfunc(args)

Expand Down
Loading

0 comments on commit c5305ae

Please sign in to comment.