
Commit

feat: fixing tests for new interp features
DanielaBreitman committed Jul 16, 2024
1 parent f5b5e77 commit 9cbd688
Showing 4 changed files with 86 additions and 34 deletions.
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

84 changes: 56 additions & 28 deletions src/powerbox/tools.py
@@ -15,15 +15,15 @@


 def _getbins(bins, coords, log):
-    try:
-        # Fails if coords is not a cube / inhomogeneous.
-        max_radius = np.min([np.max(coords, axis=i) for i in range(coords.ndim)])
-    except ValueError:
-        maxs = [np.max(coords, axis=i) for i in range(coords.ndim)]
-        maxs_flat = []
-        [maxs_flat.extend(m.ravel()) for m in maxs]
-        max_radius = np.min(maxs_flat)
     if not np.iterable(bins):
+        try:
+            # Fails if coords is not a cube / inhomogeneous.
+            max_radius = np.min([np.max(coords, axis=i) for i in range(coords.ndim)])
+        except ValueError:
+            maxs = [np.max(coords, axis=i) for i in range(coords.ndim)]
+            maxs_flat = []
+            [maxs_flat.extend(m.ravel()) for m in maxs]
+            max_radius = np.min(maxs_flat)
         if not log:
             bins = np.linspace(coords.min(), max_radius, bins + 1)
         else:
@@ -166,10 +166,11 @@ def angular_average(
         res = _field_average(indx, field, weights, sumweights)
     else:
         bins = _getbins(bins, coords_grid, log_bins)
-        if log_bins:
-            bins = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
-        else:
-            bins = (bins[1:] + bins[:-1]) / 2
+        if bin_ave:
+            if log_bins:
+                bins = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
+            else:
+                bins = (bins[1:] + bins[:-1]) / 2

         sample_coords, r_n = _sample_coords_interpolate(
             coords, bins, weights, interp_points_generator
@@ -421,10 +422,17 @@ def _field_average_interpolate(coords, field, bins, weights, sample_coords, r_n)
     # Grid is regular + can be ordered only in Cartesian coords.
     if isinstance(weights, np.ndarray):
         weights = weights.reshape(field.shape)
+        if not ((weights == 0) | (weights == 1)).all():
+            warnings.warn(
+                "Interpolating with non-binary weights is slow.",
+                UserWarning,
+                stacklevel=2,
+            )
+        else:
+            field = field * weights
     else:
         weights = np.ones_like(field) * weights
     # Set 0 weights to NaNs
-    field = field * weights
     field[weights == 0] = np.nan
     # Rescale the field (see scipy documentation for RegularGridInterpolator)
     mean, std = np.nanmean(field), np.max(
@@ -441,18 +449,37 @@ def _field_average_interpolate(coords, field, bins, weights, sample_coords, r_n)

     interped_field = fnc(sample_coords.T) * std + mean
     if np.all(np.isnan(interped_field)):
-        warnings.warn("Interpolator returned all NaNs.", stacklevel=2)
+        warnings.warn("Interpolator returned all NaNs.", RuntimeWarning, stacklevel=2)
     # Average over the spherical shells for each radius / bin value
-    avged_field = np.array([np.nanmean(interped_field[r_n == b]) for b in bins])
-    unique_rn, sumweights = np.unique(
-        r_n[~np.isnan(interped_field)], return_counts=True
-    )
-    final_sumweights = []
-    for b in bins:
-        if b in unique_rn:
-            final_sumweights.append(sumweights[unique_rn == b][0])
-        else:
-            final_sumweights.append(0)
+    if not ((weights == 0) | (weights == 1)).all():
+        fnc = RegularGridInterpolator(
+            coords,
+            weights,  # Complex data is accepted.
+            bounds_error=False,
+            fill_value=np.nan,
+        )
+        interped_weights = fnc(sample_coords.T)
+
+        avged_field = []
+
+        final_sumweights = []
+        for b in bins:
+            mbin = np.logical_and(r_n == b, ~np.isnan(interped_field))
+            avged_field.append(np.sum(interped_field[mbin] * interped_weights[mbin]))
+            final_sumweights.append(np.sum(interped_weights[mbin]))
+        avged_field = np.array(avged_field) / final_sumweights
+    else:
+        avged_field = np.array([np.nanmean(interped_field[r_n == b]) for b in bins])
+        unique_rn, sumweights = np.unique(
+            r_n[~np.isnan(interped_field)],
+            return_counts=True,
+        )
+        final_sumweights = []
+        for b in bins:
+            if b in unique_rn:
+                final_sumweights.append(sumweights[unique_rn == b][0])
+            else:
+                final_sumweights.append(0)
     return avged_field, np.array(final_sumweights)

@@ -677,10 +704,11 @@ def angular_average_nd(  # noqa: C901
         res = np.zeros((len(sumweights), n2), dtype=field.dtype)
     if interpolation_method is not None:
         bins = _getbins(bins, coords_grid, log_bins)
-        if log_bins:
-            bins = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
-        else:
-            bins = (bins[1:] + bins[:-1]) / 2
+        if bin_ave:
+            if log_bins:
+                bins = np.exp((np.log(bins[1:]) + np.log(bins[:-1])) / 2)
+            else:
+                bins = (bins[1:] + bins[:-1]) / 2
         res = np.zeros((len(bins), n2), dtype=field.dtype)

     if get_variance:
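For context (not part of the commit): the non-binary-weights branch added to _field_average_interpolate above is reached through angular_average when interpolation is requested with fractional weights. Below is a minimal usage sketch, assuming the call signature shown in the tests further down; the uniform toy field and the 0.5 weights are illustrative, not taken from the test suite.

import numpy as np

from powerbox.tools import angular_average

# Toy 40^3 field on a regular grid, mirroring the shapes used in the tests.
x = np.linspace(-3, 3, 40)
field = np.ones((40, 40, 40))
coords = [x, x, x]

# Fractional weights (not just 0/1) take the new, slower path: the weights are
# interpolated alongside the field and a UserWarning
# ("Interpolating with non-binary weights is slow.") is emitted.
weights = np.full_like(field, 0.5)
weights[:5] = 1.0

avg, k_bins = angular_average(
    field,
    coords,
    bins=10,
    interpolation_method="linear",
    weights=weights,
)
print(avg.shape, k_bins.shape)

On this path each output bin is the weighted mean sum(field * w) / sum(w) over the interpolated shell samples, rather than the plain nanmean used for binary weights.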
7 changes: 6 additions & 1 deletion tests/test_power.py
@@ -1,4 +1,5 @@
 import numpy as np
+import warnings

 from powerbox import PowerBox, get_power, ignore_zero_absk, ignore_zero_ki, power2delta

@@ -118,7 +119,11 @@ def test_prefactor_fnc():

 def test_k_weights_fnc():
     pb = PowerBox(50, dim=3, pk=lambda k: 1.0 * k**-2.0, boxlength=1.0, b=1)
-    p_ki0, k_ki0 = get_power(pb.delta_x(), pb.boxlength, k_weights=ignore_zero_ki)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore", message="divide by zero encountered in divide"
+        )
+        p_ki0, k_ki0 = get_power(pb.delta_x(), pb.boxlength, k_weights=ignore_zero_ki)
     p, k = get_power(pb.delta_x(), pb.boxlength, k_weights=ignore_zero_absk)

     assert not np.allclose(p, p_ki0)
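A side note on the filtered message (an aside, not from the commit): numpy routes its floating-point warnings through Python's warnings machinery by default, so the warnings.filterwarnings guard used in the test and numpy's own errstate context manager are interchangeable here. A tiny illustration with throwaway values:

import warnings

import numpy as np

x = np.arange(3.0)  # contains a zero

# The approach used in the test above: filter by message.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="divide by zero encountered")
    y1 = 1.0 / x

# An equivalent guard using numpy's error-state context manager.
with np.errstate(divide="ignore"):
    y2 = 1.0 / x

print(y1, y2)  # both: [inf 1. 0.5]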
26 changes: 24 additions & 2 deletions tests/test_tools.py
@@ -1,6 +1,7 @@
 import pytest

 import numpy as np
+import warnings

 from powerbox.powerbox import PowerBox
 from powerbox.tools import (
@@ -12,6 +13,23 @@
 )


+def test_warn_interp_weights():
+    x = np.linspace(-3, 3, 40)
+    P = np.ones(3 * [40])
+    weights = np.ones_like(P)
+    weights[2:5] = 0
+    freq = [x for _ in range(3)]
+    with pytest.warns(RuntimeWarning):
+        p_k_lin, k_av_bins_lin = angular_average(
+            P,
+            freq,
+            bins=10,
+            interpolation_method="linear",
+            weights=weights,
+            interp_points_generator=regular_angular_generator,
+        )
+
+
 @pytest.mark.parametrize("interpolation_method", [None, "linear"])
 def test_angular_avg_nd_3(interpolation_method):
     x = np.linspace(-3, 3, 400)
@@ -120,7 +138,11 @@ def test_interp_w_mu(n):
     x = np.linspace(0.0, 3, 40)
     if n == 2:
         kpar_mesh, kperp_mesh = np.meshgrid(x, x)
-        theta = np.arctan(kperp_mesh / kpar_mesh)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore", message="divide by zero encountered in divide"
+            )
+            theta = np.arctan2(kperp_mesh, kpar_mesh)
         mu_mesh = np.cos(theta)
     else:
         kx_mesh, ky_mesh, kz_mesh = np.meshgrid(x, x, x, indexing="ij")
@@ -150,7 +172,7 @@ def test_interp_w_mu(n):
 def test_error_coords_and_mask():
     x = np.linspace(1.0, 3, 40)
     kpar_mesh, kperp_mesh = np.meshgrid(x, x)
-    theta = np.arctan(kperp_mesh / kpar_mesh)
+    theta = np.arctan2(kperp_mesh, kpar_mesh)
     mu_mesh = np.cos(theta)

     mask = mu_mesh >= 0.97
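Why the arctan -> arctan2 swap in the last two hunks (an explanatory aside, not part of the commit): np.arctan(kperp / kpar) divides by zero wherever kpar == 0, which is the kind of warning these tests filter, whereas np.arctan2 takes the two components separately and handles kpar == 0 directly. On the non-negative grids used here the two give the same angles; a quick standalone check with illustrative values:

import numpy as np

kpar = np.array([0.0, 1.0, 2.0])
kperp = np.array([1.0, 1.0, 2.0])

# arctan(kperp / kpar) needs a guard against the kpar == 0 division ...
with np.errstate(divide="ignore"):
    theta_old = np.arctan(kperp / kpar)

# ... while arctan2 returns pi/2 at kpar == 0 with no warning.
theta_new = np.arctan2(kperp, kpar)

assert np.allclose(theta_old, theta_new)  # [pi/2, pi/4, pi/4]

(For negative kpar the two would differ, since arctan2 resolves the quadrant, but kpar >= 0 throughout these tests.)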
