
Fix some tests
robjmcgibbon committed Sep 3, 2024
1 parent 90f9e4c commit 6a1675b
Showing 11 changed files with 48 additions and 51 deletions.
12 changes: 9 additions & 3 deletions swiftsimio/__init__.py
@@ -1,5 +1,5 @@
from .reader import *
from .writer import SWIFTWriterDataset
from .snapshot_writer import SWIFTSnapshotWriter
from .masks import SWIFTMask
from .statistics import SWIFTStatisticsFile
from .__version__ import __version__
@@ -109,5 +109,11 @@ def load_statistics(filename) -> SWIFTStatisticsFile:
return SWIFTStatisticsFile(filename=filename)


# Rename this object to something simpler.
Writer = SWIFTWriterDataset
class Writer:
    def __new__(cls, *args, **kwargs):
        # Default to SWIFTSnapshotWriter if no filetype is passed
        filetype = kwargs.get("filetype", "snapshot")
        if filetype == "snapshot":
            return SWIFTSnapshotWriter(*args, **kwargs)
        # TODO implement other writers
        # elif filetype == '
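
For context on how the new dispatcher is meant to be used, here is a minimal usage sketch; the cosmo_units unit system and the 10 Mpc box are chosen purely for illustration, and with no filetype argument the call falls through to SWIFTSnapshotWriter.

import unyt
from swiftsimio import Writer
from swiftsimio.units import cosmo_units

# Illustrative unit system and box size; Writer.__new__ defaults to the
# snapshot writer when no filetype keyword is supplied.
writer = Writer(cosmo_units, unyt.unyt_array([10.0, 10.0, 10.0], unyt.Mpc))
print(type(writer).__name__)  # SWIFTSnapshotWriter
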
4 changes: 1 addition & 3 deletions swiftsimio/objects.py
@@ -961,13 +961,12 @@ def to_comoving(self):

def compatible_with_comoving(self):
"""
# TODO: Is this the same question as "can be converted to comoving?"
Is this cosmo_array compatible with a comoving cosmo_array?
This is the case if the cosmo_array is comoving, or if the scale factor
exponent is 0 (cosmo_factor.a_factor() == 1)
"""
return self.valid_transform
return self.comoving or (self.cosmo_factor.a_factor == 1.0)

def compatible_with_physical(self):
"""
@@ -976,7 +975,6 @@ def compatible_with_physical(self):
This is the case if the cosmo_array is physical, or if the scale factor
exponent is 0 (cosmo_factor.a_factor == 1)
"""
# TODO: Isn't this always true? We can convert it to physical if needed?
return (not self.comoving) or (self.cosmo_factor.a_factor == 1.0)

@classmethod
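
The new return line ties compatibility to the comoving flag and the scale-factor exponent rather than to valid_transform. A standalone sketch of that rule using plain booleans and floats (not the real cosmo_array API):

# Compatible with comoving: already comoving, or the quantity scales as a^0
# (a_factor == 1), so the comoving/physical distinction does not matter.
def compatible_with_comoving(comoving: bool, a_factor: float) -> bool:
    return comoving or a_factor == 1.0

assert compatible_with_comoving(True, 0.5)       # already comoving
assert compatible_with_comoving(False, 1.0)      # physical, but a^0 scaling
assert not compatible_with_comoving(False, 0.5)  # physical and a-dependent
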
7 changes: 3 additions & 4 deletions swiftsimio/reader.py
@@ -553,10 +553,9 @@ def postprocess_header(self):
self.reduced_lightspeed = None

# Store these separately as self.n_gas = number of gas particles for example
for (
part_number,
part_name,
) in metadata.particle_types.particle_name_underscores.items():
for (part_number, (_, part_name)) in enumerate(
metadata.particle_types.particle_name_underscores.items()
):
try:
setattr(self, f"n_{part_name}", self.num_part[part_number])
except IndexError:
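
The rewritten loop pairs enumerate() with the dictionary items, so the particle-type index no longer has to come from the dictionary keys. A small sketch of the pattern with an illustrative mapping:

# particle_name_underscores now maps group names to underscored names; the
# numeric index comes from enumerate(). The mapping here is illustrative only.
particle_name_underscores = {"PartType0": "gas", "PartType1": "dark_matter"}

for part_number, (_, part_name) in enumerate(particle_name_underscores.items()):
    print(part_number, part_name)  # 0 gas, then 1 dark_matter
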
8 changes: 4 additions & 4 deletions swiftsimio/writer.py → swiftsimio/snapshot_writer.py
@@ -1,7 +1,7 @@
"""
Contains functions and objects for creating SWIFT datasets.
Essentially all you want to do is use SWIFTWriterDataset and fill the attributes
Essentially all you want to do is use SWIFTSnapshotWriter and fill the attributes
that are required for each particle type. More information is available in the
README.
"""
@@ -271,7 +271,7 @@ def get_attributes(self, scale_factor: float) -> dict:

# Find the scale factor associated quantities
a_exp = a_exponents.get(name, 0)
a_factor = scale_factor**a_exp
a_factor = scale_factor ** a_exp

attributes_dict[output_handle] = {
"Conversion factor to CGS (not including cosmological corrections)": [
@@ -493,7 +493,7 @@ def generate_dataset(
return empty_dataset


class SWIFTWriterDataset(object):
class SWIFTSnapshotWriter(object):
"""
The SWIFT writer dataset. This is used to store all particle arrays and do
some extra processing before writing a HDF5 file containing:
@@ -516,7 +516,7 @@ def __init__(
scale_factor: np.float32 = 1.0,
):
"""
Creates SWIFTWriterDataset object
Creates SWIFTSnapshotWriter object
Parameters
----------
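
For the get_attributes() hunk above, the cosmological correction is simply the scale factor raised to each dataset's exponent. A tiny sketch with made-up exponents:

# a_exp comes from a per-dataset exponent table; the values here are illustrative.
a_exponents = {"coordinates": 1.0, "masses": 0.0}
scale_factor = 0.5

for name, a_exp in a_exponents.items():
    a_factor = scale_factor ** a_exp
    print(name, a_factor)  # coordinates 0.5, masses 1.0
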
7 changes: 3 additions & 4 deletions swiftsimio/subset_writer.py
@@ -28,14 +28,13 @@ def get_swift_name(name: str) -> str:
str
SWIFT particle type corresponding to `name` (e.g. PartType0)
"""
part_type_nums = [
part_type_names = [
k for k, v in metadata.particle_types.particle_name_underscores.items()
]
part_types = [
v for k, v in metadata.particle_types.particle_name_underscores.items()
]
part_type_num = part_type_nums[part_types.index(name)]
return f"PartType{part_type_num}"
return part_type_names[part_types.index(name)]


def get_dataset_mask(
@@ -66,7 +65,7 @@ def get_dataset_mask(
suffix = "" if suffix is None else suffix

if "PartType" in dataset_name:
part_type = [int(x) for x in filter(str.isdigit, dataset_name)][0]
part_type = dataset_name.lstrip("/").split("/")[0]
mask_name = metadata.particle_types.particle_name_underscores[part_type]
return getattr(mask, f"{mask_name}{suffix}", None)
else:
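
Both changes in this file lean on the same idea: the particle-type dictionaries are keyed by the HDF5 group name, so the group name can be taken straight from the dataset path and looked up directly. A sketch with an illustrative mapping and path:

# The leading path component is the PartType key (robust to multi-digit types,
# unlike picking out single digits); mapping and path are illustrative.
particle_name_underscores = {"PartType0": "gas", "PartType7": "extratype"}

dataset_name = "/PartType7/Coordinates"
part_type = dataset_name.lstrip("/").split("/")[0]  # "PartType7"
mask_name = particle_name_underscores[part_type]    # "extratype"
print(part_type, mask_name)
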
1 change: 0 additions & 1 deletion swiftsimio/visualisation/__init__.py
@@ -6,4 +6,3 @@
from .slice import slice_scatter as slice
from .slice import slice_gas, slice_gas_pixel_grid
from .smoothing_length import generate_smoothing_lengths

22 changes: 11 additions & 11 deletions swiftsimio/visualisation/power_spectrum.py
@@ -10,7 +10,7 @@
from swiftsimio.optional_packages import tqdm
from swiftsimio.accelerated import jit, NUM_THREADS, prange
from swiftsimio import cosmo_array
from swiftsimio.reader import __SWIFTParticleDataset
from swiftsimio.reader import __SWIFTGroupDatasets

from typing import Optional, Dict, Tuple

@@ -169,7 +169,7 @@ def deposit_parallel(


def render_to_deposit(
data: __SWIFTParticleDataset,
data: __SWIFTGroupDatasets,
resolution: int,
project: str = "masses",
folding: int = 0,
@@ -199,7 +199,7 @@
"""

# Get the positions and masses
folding = 2.0**folding
folding = 2.0 ** folding
positions = data.coordinates
quantity = getattr(data, project)

@@ -244,10 +244,10 @@
units = 1.0 / (
data.metadata.boxsize[0] * data.metadata.boxsize[1] * data.metadata.boxsize[2]
)
units.convert_to_units(1.0 / data.metadata.boxsize.units**3)
units.convert_to_units(1.0 / data.metadata.boxsize.units ** 3)

units *= quantity.units
new_cosmo_factor = quantity.cosmo_factor / (coord_cosmo_factor**3)
new_cosmo_factor = quantity.cosmo_factor / (coord_cosmo_factor ** 3)

return cosmo_array(
deposition, comoving=comoving, cosmo_factor=new_cosmo_factor, units=units
@@ -399,7 +399,7 @@ def folded_depositions_to_power_spectrum(

if folding != final_folding:
cutoff_wavenumber = (
2.0**folding * np.min(depositions[folding].shape) / np.min(box_size)
2.0 ** folding * np.min(depositions[folding].shape) / np.min(box_size)
)

if cutoff_above_wavenumber_fraction is not None:
@@ -424,7 +424,7 @@
corrected_wavenumber_centers[prefer_bins] = folded_wavenumber_centers[
prefer_bins
].to(corrected_wavenumber_centers.units)
folding_tracker[prefer_bins] = 2.0**folding
folding_tracker[prefer_bins] = 2.0 ** folding

contributed_counts[prefer_bins] = folded_counts[prefer_bins]
elif transition == "average":
@@ -457,7 +457,7 @@

# For debugging, we calculate an effective fold number.
folding_tracker[use_bins] = (
(folding_tracker * existing_weight + (2.0**folding) * new_weight)
(folding_tracker * existing_weight + (2.0 ** folding) * new_weight)
/ transition_norm
)[use_bins]

@@ -538,7 +538,7 @@ def deposition_to_power_spectrum(
deposition.shape == cross_deposition.shape
), "Depositions must have the same shape"

folding = 2.0**folding
folding = 2.0 ** folding

box_size_folded = box_size[0] / folding
npix = deposition.shape[0]
@@ -560,7 +560,7 @@
else:
conj_fft = fft.conj()

fourier_amplitudes = (fft * conj_fft).real * box_size_folded**3
fourier_amplitudes = (fft * conj_fft).real * box_size_folded ** 3

# Calculate k-value spacing (centered FFT)
dk = 2 * np.pi / (box_size_folded)
@@ -592,7 +592,7 @@
divisor[zero_mask] = 1

# Correct for folding
binned_amplitudes *= folding**3
binned_amplitudes *= folding ** 3

# Correct units and names
wavenumbers = unyt.unyt_array(
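
The folding arithmetic that recurs through this file follows one rule: a fold index f shrinks the effective box by 2**f, which raises the accessible k-range by the same factor, and the binned power picks up a factor folding**3 to compensate for the reduced volume. A short numerical sketch (box size and pixel count are illustrative):

import numpy as np

box_size = 100.0  # illustrative box side length
npix = 128        # illustrative deposition resolution

for f in range(3):
    folding = 2.0 ** f
    box_size_folded = box_size / folding
    dk = 2 * np.pi / box_size_folded        # k-bin spacing grows with folding
    k_max = np.pi * npix / box_size_folded  # so does the reachable k-range
    volume_correction = folding ** 3        # applied to the binned amplitudes
    print(f, round(dk, 4), round(k_max, 2), volume_correction)
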
8 changes: 2 additions & 6 deletions swiftsimio/visualisation/smoothing_length/__init__.py
@@ -1,10 +1,6 @@
from .sph import get_hsml as sph_get_hsml
from .nearest_neighbours import (
get_hsml as nearest_neighbours_get_hsml,
)
from .generate import (
generate_smoothing_lengths,
)
from .nearest_neighbours import get_hsml as nearest_neighbours_get_hsml
from .generate import generate_smoothing_lengths

backends_get_hsml = {
"sph": sph_get_hsml,
2 changes: 1 addition & 1 deletion tests/subset_write_test.py
@@ -34,7 +34,7 @@ def compare_data_contents(A, B):
A_type = getattr(A, part_type)
B_type = getattr(B, part_type)
particle_dataset_field_names = set(
A_type.particle_metadata.field_names + B_type.particle_metadata.field_names
A_type.group_metadata.field_names + B_type.group_metadata.field_names
)

for attr in particle_dataset_field_names:
4 changes: 2 additions & 2 deletions tests/test_data.py
@@ -104,8 +104,8 @@ def test_units(filename):

# Now need to extract the particle paths in the original hdf5 file
# for comparison...
paths = numpy_array(field.particle_metadata.field_paths)
names = numpy_array(field.particle_metadata.field_names)
paths = numpy_array(field.group_metadata.field_paths)
names = numpy_array(field.group_metadata.field_names)

for property in properties:
# Read the 0th element, and compare in CGS units.
24 changes: 12 additions & 12 deletions tests/test_extraparts.py
@@ -87,9 +87,9 @@ def test_write():
)
# Specify a new type in the metadata - currently done by editing the dictionaries directly.
# TODO: Remove this terrible way of setting up different particle types.
swp.particle_name_underscores[6] = "extratype"
swp.particle_name_class[6] = "Extratype"
swp.particle_name_text[6] = "Extratype"
swp.particle_name_underscores["PartType7"] = "extratype"
swp.particle_name_class["PartType7"] = "Extratype"
swp.particle_name_text["PartType7"] = "Extratype"

swmw.extratype = {"smoothing_length": "SmoothingLength", **swmw.shared}

@@ -110,19 +110,19 @@ def test_write():
x.write("extra_test.hdf5")

# Clean up these global variables we screwed around with...
swp.particle_name_underscores.pop(6)
swp.particle_name_class.pop(6)
swp.particle_name_text.pop(6)
swp.particle_name_underscores.pop("PartType7")
swp.particle_name_class.pop("PartType7")
swp.particle_name_text.pop("PartType7")


def test_read():
"""
Tests whether swiftsimio can handle a new particle type. Has a few asserts to check the
data is read in correctly.
"""
swp.particle_name_underscores[6] = "extratype"
swp.particle_name_class[6] = "Extratype"
swp.particle_name_text[6] = "Extratype"
swp.particle_name_underscores["PartType7"] = "extratype"
swp.particle_name_class["PartType7"] = "Extratype"
swp.particle_name_text["PartType7"] = "Extratype"

swmw.extratype = {"smoothing_length": "SmoothingLength", **swmw.shared}

@@ -136,6 +136,6 @@ def test_read():
os.remove("extra_test.hdf5")

# Clean up these global variables we screwed around with...
swp.particle_name_underscores.pop(6)
swp.particle_name_class.pop(6)
swp.particle_name_text.pop(6)
swp.particle_name_underscores.pop("PartType7")
swp.particle_name_class.pop("PartType7")
swp.particle_name_text.pop("PartType7")
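
The registration/cleanup pair in these tests now works on string keys. One possible shape for the same bookkeeping, wrapped in try/finally so the global dictionaries are restored even if the test body raises (swp stands for the module the tests already import for these dictionaries; this is a sketch, not what the commit does):

swp.particle_name_underscores["PartType7"] = "extratype"
swp.particle_name_class["PartType7"] = "Extratype"
swp.particle_name_text["PartType7"] = "Extratype"
try:
    pass  # read or write the dataset containing the extra particle type
finally:
    for mapping in (
        swp.particle_name_underscores,
        swp.particle_name_class,
        swp.particle_name_text,
    ):
        mapping.pop("PartType7")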
