Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use only a single array for masking #200

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 116 additions & 47 deletions swiftsimio/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,79 @@ def __init__(self, metadata: SWIFTMetadata, spatial_only=True):
if not spatial_only:
self._generate_empty_masks()

def _generate_empty_masks(self):

def _generate_mapping_dictionary(self) -> dict[str, str]:
    """
    Creates cross-links between 'group names' and their underlying cell
    metadata names. Allows for pointers to be used instead of re-creating
    masks.

    The mapping is built once per call (not once per group) and is also
    stored on ``self.group_mapping``, so it is defined even when there are
    no present groups.

    Returns
    -------
    dict[str, str]
        Maps every entry of ``metadata.present_group_names`` to the name of
        the attribute holding its mask data: ``"_<group>"`` when
        ``metadata.shared_cell_counts`` is None, ``"_shared"`` otherwise.
    """

    if self.metadata.shared_cell_counts is None:
        # Each and every particle type has its own cell counts, offsets,
        # and hence masks.
        self.group_mapping = {
            group: f"_{group}" for group in self.metadata.present_group_names
        }
    else:
        # We actually only have _one_ mask!
        self.group_mapping = {
            group: "_shared" for group in self.metadata.present_group_names
        }

    return self.group_mapping

def _generate_update_list(self) -> list[str]:
    """
    Names the internal mask attributes that have to be refreshed whenever
    the spatial mask changes.

    Returns
    -------
    list[str]
        ``["_shared"]`` when ``metadata.shared_cell_counts`` is set,
        otherwise one ``"_<group>"`` entry per present group name.
    """

    metadata = self.metadata

    if metadata.shared_cell_counts is not None:
        # A single mask is shared between all groups.
        return ["_shared"]

    # Every particle type carries its own cell counts, offsets and mask.
    return [f"_{name}" for name in metadata.present_group_names]

def _create_pointers(self):
    """
    Exposes the underlying mask data under each group's public name.

    For every ``group_name -> data_name`` pair returned by
    ``_generate_mapping_dictionary`` this sets ``self.<group_name>`` to the
    very same object as ``self.<data_name>`` (no copy is made) and mirrors
    ``self.<data_name>_size`` onto ``self.<group_name>_size``.

    Must be called again whenever the underlying attributes are re-bound,
    otherwise the public names keep referring to the previous objects.
    """

    # Create pointers for every single particle type.
    for group_name, data_name in self._generate_mapping_dictionary().items():
        # setattr takes exactly (object, name, value); the value is the
        # shared underlying mask, not a freshly allocated array.
        setattr(self, group_name, getattr(self, data_name))
        setattr(self, f"{group_name}_size", getattr(self, f"{data_name}_size"))


def _generate_empty_masks(self):
    """
    Allocates the initial boolean masks for all available particle types.

    Every mask starts with all elements True. With shared cell counts a
    single ``_shared`` mask (and ``_shared_size``) is created, sized from
    ``metadata.n_<shared_cell_counts lower-cased>``; otherwise one
    ``_<group>`` mask and ``_<group>_size`` is created per group, sized from
    ``metadata.n_<group>``. The per-group pointers are then refreshed via
    ``_create_pointers``.
    """

    mapping = self._generate_mapping_dictionary()
    shared_name = self.metadata.shared_cell_counts

    if shared_name is None:
        # One independent mask per particle type.
        for group_name, data_name in mapping.items():
            n_elements = getattr(self.metadata, f"n_{group_name}")
            setattr(self, data_name, np.ones(n_elements, dtype=bool))
            setattr(self, f"{data_name}_size", n_elements)
    else:
        # A single mask backs every group.
        n_elements = getattr(self.metadata, f"n_{shared_name.lower()}")
        self._shared = np.ones(n_elements, dtype=bool)
        self._shared_size = n_elements

    self._create_pointers()

def _unpack_cell_metadata(self):
Expand All @@ -104,29 +164,39 @@ def _unpack_cell_metadata(self):
# file i/o implemented
offset_handle = cell_handle["Offsets"]


if self.metadata.shared_cell_counts is not None:
# Single - called _shared.
self.offsets["shared"] = offset_handle[self.metadata.shared_cell_counts][:]
self.counts["shared"] = count_handle[self.metadata.shared_cell_counts][:]
else:
for group, group_name in zip(
self.metadata.present_groups, self.metadata.present_group_names
):
counts = count_handle[group][:]
offsets = offset_handle[group][:]

self.offsets[group_name] = offset_handle[group][:]
self.counts[group_name] = count_handle[group][:]

# Only want to compute this once (even if it is fast, we do not
# have a reliable stable sort in the case where cells do not
# contain at least one of each type of particle).
sort = None

for group, group_name in zip(
self.metadata.present_groups, self.metadata.present_group_names
):
if self.metadata.shared_cell_counts is None:
counts = count_handle[group][:]
offsets = offset_handle[group][:]
else:
counts = count_handle[self.metadata.shared_cell_counts][:]
offsets = offset_handle[self.metadata.shared_cell_counts][:]
# Now perform sort:
for key in self.offsets.keys():
offsets = self.offsets[key]
counts = self.counts[key]

# When using MPI, we cannot assume that these are sorted.
if sort is None:
# Only compute once; not stable between particle
# types if some datasets do not have particles in a cell!
sort = np.argsort(offsets)

self.offsets[group_name] = offsets[sort]
self.counts[group_name] = counts[sort]
self.offsets[key] = offsets[sort]
self.counts[key] = counts[sort]

# Also need to sort centers in the same way
self.centers = unyt.unyt_array(centers_handle[:][sort], units=self.units.length)
Expand Down Expand Up @@ -180,8 +250,11 @@ def constrain_mask(
print("You cannot constrain a mask if spatial_only=True")
print("Please re-initialise the SWIFTMask object with spatial_only=False")
return

mapping = self._generate_mapping_dictionary()
data_name = mapping[group_name]

current_mask = getattr(self, group_name)
current_mask = getattr(self, data_name)

group_metadata = getattr(self.metadata, f"{group_name}_properties")
unit_dict = {
Expand Down Expand Up @@ -209,7 +282,7 @@ def constrain_mask(

current_mask[current_mask] = new_mask

setattr(self, group_name, current_mask)
setattr(self, data_name, current_mask)

return

Expand Down Expand Up @@ -288,7 +361,7 @@ def _generate_cell_mask(self, restrict):

return cell_mask

def _update_spatial_mask(self, restrict, group_name: str, cell_mask: np.array):
def _update_spatial_mask(self, restrict, data_name: str, cell_mask: np.array):
"""
Updates the particle mask using the cell mask.

Expand All @@ -302,28 +375,30 @@ def _update_spatial_mask(self, restrict, group_name: str, cell_mask: np.array):
restrict : list
currently unused

group_name : str
particle type to update
data_name : str
underlying data to update (e.g. _gas, _shared)

cell_mask : np.array
cell mask used to update the particle mask
"""

count_name = data_name[1:] # Remove the underscore

if self.spatial_only:
counts = self.counts[group_name][cell_mask]
offsets = self.offsets[group_name][cell_mask]
counts = self.counts[count_name][cell_mask]
offsets = self.offsets[count_name][cell_mask]

this_mask = [[o, c + o] for c, o in zip(counts, offsets)]

setattr(self, group_name, np.array(this_mask))
setattr(self, f"{group_name}_size", np.sum(counts))
setattr(self, data_name, np.array(this_mask))
setattr(self, f"{data_name}_size", np.sum(counts))

else:
counts = self.counts[group_name][~cell_mask]
offsets = self.offsets[group_name][~cell_mask]
counts = self.counts[count_name][~cell_mask]
offsets = self.offsets[count_name][~cell_mask]

# We must do the whole boolean mask business.
this_mask = getattr(self, group_name)
this_mask = getattr(self, data_name)

for count, offset in zip(counts, offsets):
this_mask[offset : count + offset] = False
Expand Down Expand Up @@ -373,8 +448,10 @@ def constrain_spatial(self, restrict, intersect: bool = False):
# we just make a new mask
self.cell_mask = self._generate_cell_mask(restrict)

for group_name in self.metadata.present_group_names:
self._update_spatial_mask(restrict, group_name, self.cell_mask)
for mask in self._generate_update_list():
self._update_spatial_mask(restrict, mask, self.cell_mask)

self._create_pointers()

return

Expand All @@ -388,27 +465,19 @@ def convert_masks_to_ranges(self):
If you don't know what you are doing please don't use this.
"""

if self.spatial_only:
# We are already done!
return
else:
# Spatial only already comes like this!
if not self.spatial_only:
# We must do the whole boolean mask stuff. To do that, we first
# convert each boolean mask into an integer index array, then use
# the accelerate.ranges_from_array function to convert that into
# a set of ranges.
for mask in self._generate_update_list():
where_array = np.where(getattr(self, mask))[0]
setattr(self, f"{mask}_size", where_array.size)
print(mask, where_array)
setattr(self, mask, ranges_from_array(where_array))

for group_name in self.metadata.present_group_names:
setattr(
self,
group_name,
# Because it nests things in a list for some reason.
np.where(getattr(self, group_name))[0],
)

setattr(self, f"{group_name}_size", getattr(self, group_name).size)

for group_name in self.metadata.present_group_names:
setattr(self, group_name, ranges_from_array(getattr(self, group_name)))
self._create_pointers()

return

Expand All @@ -431,7 +500,7 @@ def constrain_index(self, index: int):
setattr(self, f"{group_name}_size", 1)
return

def get_masked_counts_offsets(self) -> (Dict[str, np.array], Dict[str, np.array]):
def get_masked_counts_offsets(self) -> tuple[dict[str, np.array], dict[str, np.array]]:
"""
Returns the particle counts and offsets in cells selected by the mask

Expand Down
4 changes: 4 additions & 0 deletions swiftsimio/metadata/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,13 +1272,17 @@ def __init__(self, filename: str, units: SWIFTUnits):

self.get_metadata()
self.postprocess_header()
self.unpack_subhalo_number()

self.load_groups()

# After we've loaded all this metadata, we can safely release the file handle.
self.handle.close()

return

def unpack_subhalo_number(self):
    """
    Stores the number of subhalos as a plain ``int`` on ``self.n_subhalos``.

    The value is the first element of ``self.num_subhalo``.
    """
    # NOTE(review): num_subhalo is presumably filled in by get_metadata(),
    # which runs before this in __init__ - confirm.
    first_entry = self.num_subhalo[0]
    self.n_subhalos = int(first_entry)

@property
def present_groups(self):
Expand Down
30 changes: 29 additions & 1 deletion tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,39 @@

from tests.helper import requires

from swiftsimio import load
from swiftsimio import load, mask


@requires("soap_example.hdf5")
def test_soap_can_load(filename):
    """
    Smoke test: calling ``load`` on the SOAP example file must not raise.
    """
    load(filename)


@requires("soap_example.hdf5")
def test_soap_can_mask_spatial(filename):
    """
    Smoke test: a spatial-only mask covering the lower half of the box along
    every axis can be built and used to load the SOAP example file.
    """
    spatial_mask = mask(filename, spatial_only=True)

    boxsize = spatial_mask.metadata.boxsize
    lower_half = [[0 * length, 0.5 * length] for length in boxsize]
    spatial_mask.constrain_spatial(lower_half)

    masked_data = load(filename, mask=spatial_mask)

    # Touch one element of a dataset; the test passes if this does not raise.
    masked_data.spherical_overdensity_200_mean.total_mass[0]


@requires("soap_example.hdf5")
def test_soap_can_mask_non_spatial(filename):
    """
    Smoke test: the same lower-half-of-the-box constraint works when the
    mask is created with ``spatial_only=False``.
    """
    full_mask = mask(filename, spatial_only=False)

    boxsize = full_mask.metadata.boxsize
    lower_half = [[0 * length, 0.5 * length] for length in boxsize]
    full_mask.constrain_spatial(lower_half)

    masked_data = load(filename, mask=full_mask)

    # Touch one element of a dataset; the test passes if this does not raise.
    masked_data.spherical_overdensity_200_mean.total_mass[0]
Loading