diff --git a/src/powerbox/tools.py b/src/powerbox/tools.py index 075815c..bf074ae 100644 --- a/src/powerbox/tools.py +++ b/src/powerbox/tools.py @@ -206,7 +206,8 @@ def _field_average_interpolate(coords, field, bins, weights, angular_resolution= [ np.round(2 * np.pi * bins / angular_resolution), np.ones_like(bins) / angular_resolution, - ] + ], + axis=0, ), dtype=int, ) @@ -307,6 +308,7 @@ def angular_average_nd( get_variance=False, log_bins=False, interpolation_method=None, + angular_resolution=0.1, ): """ Average the first n dimensions of a given field within radial bins. @@ -326,7 +328,7 @@ def angular_average_nd( An array of arbitrary dimension specifying the field to be angularly averaged. coords : list of n arrays - A list of 1D arrays specifying the co-ordinates in each dimension *to be average*. + A list of 1D arrays specifying the co-ordinates in each dimension *to be averaged*. bins : int or array. Specifies the radial bins for the averaged dimensions. Can be an int or array specifying radial bin edges. @@ -394,7 +396,7 @@ def angular_average_nd( if len(coords) != len(field.shape): raise ValueError("coords should be a list of arrays, one for each dimension.") - if n == len(coords): + if n == len(coords) and interpolation_method is None: return angular_average( field, coords, @@ -406,10 +408,10 @@ def angular_average_nd( log_bins=log_bins, ) - coords = _magnitude_grid([c for i, c in enumerate(coords) if i < n]) + coords_grid = _magnitude_grid([c for i, c in enumerate(coords) if i < n]) indx, bins, sumweights = _get_binweights( - coords, weights, bins, average, bin_ave=bin_ave, log_bins=log_bins + coords_grid, weights, bins, average, bin_ave=bin_ave, log_bins=log_bins ) n1 = np.prod(field.shape[:n]) @@ -429,7 +431,13 @@ def angular_average_nd( if get_variance: var[:, i] = _field_variance(indx, fld, res[:, i], w, sumweights) elif interpolation_method == "linear": - res[:, i], sumweights = _field_average_interpolate(coords, fld, bins, w) + res[:, i], sumweights = _field_average_interpolate( + np.array(coords)[:n, ...], + fld.reshape(field.shape[:n]), + bins, + w, + angular_resolution=angular_resolution, + ) if get_variance: # TODO: Implement variance calculation for interpolation var[:, i] = np.zeros_like(res[:, i])