Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 60 #61

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions concatenator/attribute_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,32 @@ def regroup_coordinate_attribute(attribute_string: str) -> str:
"""
# Use the separator that's in the attribute string only if all separators in the string are the same.
# Otherwise, we will use our own default separator.
whitespaces = re.findall(r'\s+', attribute_string)
whitespaces = re.findall(r"\s+", attribute_string)
if len(set(whitespaces)) <= 1:
new_sep = whitespaces[0]
else:
new_sep = COORD_DELIM

return new_sep.join(
'/'.join(c.split(GROUP_DELIM))[1:]
for c
in attribute_string.split() # split on any whitespace
"/".join(c.split(GROUP_DELIM))[1:]
for c in attribute_string.split() # split on any whitespace
)


def flatten_coordinate_attribute_paths(
    dataset: netCDF4.Dataset, var: netCDF4.Variable, variable_name: str
) -> None:
    """Rewrite a variable's ``coordinates`` attribute so that its paths are flattened.

    When ``var`` has no ``coordinates`` attribute, nothing is changed. Otherwise the
    attribute value is passed through ``_flatten_coordinate_attribute`` and the result
    is stored on ``dataset.variables[variable_name]``.

    Parameters
    ----------
    dataset : netCDF4.Dataset
        Dataset holding the variable whose attribute is updated.
    var : netCDF4.Variable
        Variable that is inspected for a ``coordinates`` attribute.
    variable_name : str
        Key in ``dataset.variables`` of the variable that receives the new attribute.
    """
    attr_name = "coordinates"

    # Guard clause: variables without the attribute are left untouched.
    if attr_name not in var.ncattrs():
        return

    flattened = _flatten_coordinate_attribute(var.getncattr(attr_name))
    dataset.variables[variable_name].setncattr(attr_name, flattened)


def _flatten_coordinate_attribute(attribute_string: str) -> str:
"""Converts attributes that specify group membership via "/" to use new group delimiter, even for the root level.
"""Converts attributes with "/" delimiters to use new group delimiter, even for the root level.

Examples
--------
Expand All @@ -73,15 +72,18 @@ def _flatten_coordinate_attribute(attribute_string: str) -> str:
"""
# Use the separator that's in the attribute string only if all separators in the string are the same.
# Otherwise, we will use our own default separator.
whitespaces = re.findall(r'\s+', attribute_string)
if len(set(whitespaces)) <= 1:
whitespaces = re.findall(r"\s+", attribute_string)
if len(set(whitespaces)) == 1:
new_sep = whitespaces[0]
else:
new_sep = COORD_DELIM

# A new string is constructed.
return new_sep.join(
f'{GROUP_DELIM}{c.replace("/", GROUP_DELIM)}'
for c
in attribute_string.split() # split on any whitespace
)
return new_sep.join(flatten_variable_path_str(item) for item in attribute_string.split())


def flatten_variable_path_str(path_str: str) -> str:
    """Return ``path_str`` with "/" replaced by the group delimiter, prefixed by the delimiter.

    Every "/" in the path is replaced by ``GROUP_DELIM``. If the result does not already
    begin with ``GROUP_DELIM``, one is prepended, so root-level names carry the prefix too.
    """
    flattened = path_str.replace("/", GROUP_DELIM)

    if flattened.startswith(GROUP_DELIM):
        # The leading delimiter is already present (e.g. the input began with "/").
        return flattened

    return f"{GROUP_DELIM}{flattened}"
43 changes: 37 additions & 6 deletions concatenator/stitchee.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import xarray as xr

from concatenator import GROUP_DELIM
from concatenator.attribute_handling import flatten_variable_path_str
from concatenator.dimension_cleanup import remove_duplicate_dims
from concatenator.file_ops import add_label_to_path
from concatenator.group_handling import (
Expand All @@ -27,6 +28,7 @@ def stitchee(
concat_method: str = "xarray-concat",
concat_dim: str = "",
concat_kwargs: dict | None = None,
variables_to_include: list[str] | None = None,
logger: Logger = default_logger,
) -> str:
"""Concatenate netCDF data files along an existing dimension.
Expand All @@ -35,8 +37,16 @@ def stitchee(
----------
files_to_concat : list[str]
output_file : str
keep_tmp_files : bool
write_tmp_flat_concatenated : bool, optional
keep_tmp_files : bool, optional
concat_method : str, optional
Either 'xarray-concat' or 'xarray-combine'
concat_dim : str, optional
concat_kwargs : dict, optional
Keyword arguments to pass through to the xarray concatenation method
variables_to_include : list[str], optional
Names of variables to include. All other variables are excluded from the result

logger : logging.Logger

Returns
Expand All @@ -59,6 +69,14 @@ def stitchee(
"'concat_dim' was specified, but will not be used because xarray-combine method was selected."
)

# Convert the supplied variable names to their flattened forms.
if variables_to_include is not None:
variables_to_include_flattened = [
flatten_variable_path_str(v) for v in variables_to_include
]
else:
variables_to_include_flattened = None

logger.info("Flattening all input files...")
xrdataset_list = []

Expand All @@ -67,10 +85,21 @@ def stitchee(
# The group structure is flattened.
start_time = time.time()
logger.info(" ..file %03d/%03d <%s>..", i + 1, num_input_files, filepath)
flat_dataset, coord_vars, _ = flatten_grouped_dataset(
flat_dataset, coord_vars, string_vars = flatten_grouped_dataset(
nc.Dataset(filepath, "r"), filepath, ensure_all_dims_are_coords=True
)

if variables_to_include_flattened is not None:
variables_to_delete = [
var_name
for var_name, _ in flat_dataset.variables.items()
if (var_name not in variables_to_include_flattened)
and (var_name not in coord_vars)
]

for var_name in variables_to_delete:
del flat_dataset.variables[var_name]

logger.info("Removing duplicate dimensions")
flat_dataset = remove_duplicate_dims(flat_dataset)

Expand Down Expand Up @@ -101,22 +130,24 @@ def stitchee(
# coords='minimal',
# compat='override')

# Establish default concatenation keyword arguments if not supplied as input.
if concat_kwargs is None:
concat_kwargs = {}
if "data_vars" not in concat_kwargs:
concat_kwargs["data_vars"] = "minimal"
if "coords" not in concat_kwargs:
concat_kwargs["coords"] = "minimal"

# Perform concatenation operation.
if concat_method == "xarray-concat":
combined_ds = xr.concat(
xrdataset_list,
dim=GROUP_DELIM + concat_dim,
data_vars="minimal",
coords="minimal",
**concat_kwargs,
)
elif concat_method == "xarray-combine":
combined_ds = xr.combine_by_coords(
xrdataset_list,
data_vars="minimal",
coords="minimal",
**concat_kwargs,
)
else:
Expand Down
23 changes: 18 additions & 5 deletions tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from concatenator import concat_with_nco
from concatenator.attribute_handling import flatten_variable_path_str
from concatenator.stitchee import stitchee


Expand All @@ -37,13 +38,14 @@ def run_verification_with_stitchee(
concat_method: str = "xarray-concat",
record_dim_name: str = "mirror_step",
concat_kwargs: dict | None = None,
variables_to_include: list[str] | None = None,
):
output_path = str(self.__output_path.joinpath(output_name)) # type: ignore
data_path = self.__test_data_path.joinpath(data_dir) # type: ignore

input_files = []
for filepath in data_path.iterdir():
if Path(filepath).suffix.lower() in (".nc", ".h5", ".hdf"):
if Path(filepath).suffix.lower() in (".nc", ".nc4", ".h5", ".hdf"):
copied_input_new_path = self.__output_path / Path(filepath).name # type: ignore
shutil.copyfile(filepath, copied_input_new_path)
input_files.append(str(copied_input_new_path))
Expand All @@ -59,16 +61,27 @@ def run_verification_with_stitchee(
concat_method=concat_method,
concat_dim=record_dim_name,
concat_kwargs=concat_kwargs,
variables_to_include=variables_to_include,
)

merged_dataset = nc.Dataset(output_path)

# Verify that the length of the record dimension in the concatenated file equals
# the sum of the lengths across the input files
length_sum = 0
for file in input_files:
length_sum += len(nc.Dataset(file).variables[record_dim_name])
assert length_sum == len(merged_dataset.variables[record_dim_name])
with nc.Dataset(file) as ds:
length_sum += ds.dimensions[flatten_variable_path_str(record_dim_name)].size

with nc.Dataset(output_path) as merged_dataset:
if record_dim_name in merged_dataset.variables:
# Primary dimension is a root level variable
assert length_sum == len(merged_dataset.variables[record_dim_name])
elif record_dim_name in merged_dataset.dimensions:
# Primary dimension is a root level dimension, but not a variable
assert length_sum == merged_dataset.dimensions[record_dim_name].size
else:
raise AttributeError(
"Unexpected condition, where primary record dimension is not at the root level."
)

def run_verification_with_nco(self, data_dir, output_name, record_dim_name="mirror_step"):
output_path = str(self.__output_path.joinpath(output_name))
Expand Down
60 changes: 47 additions & 13 deletions tests/test_group_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,61 @@

# pylint: disable=C0116, C0301

from concatenator.attribute_handling import (_flatten_coordinate_attribute,
regroup_coordinate_attribute)
from concatenator.attribute_handling import (
_flatten_coordinate_attribute,
regroup_coordinate_attribute,
)


def test_coordinate_attribute_flattening_with_no_leading_slash():
    """Paths without a leading "/" gain the group-delimiter prefix when flattened."""
    # Case with groups present.
    # NOTE(review): the original comment describes this case as using double spaces, but
    # the separators appear as single spaces here -- confirm whitespace was not collapsed.
    grouped_input = "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude"
    grouped_expected = "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude"
    assert _flatten_coordinate_attribute(grouped_input) == grouped_expected

    # Case with NO groups present and single spaces.
    ungrouped_input = "time longitude latitude ozone_profile_pressure ozone_profile_altitude"
    ungrouped_expected = (
        "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude"
    )
    assert _flatten_coordinate_attribute(ungrouped_input) == ungrouped_expected

def test_coordinate_attribute_flattening_with_a_leading_slash():
    """Paths that already start with "/" flatten to the same delimiter-prefixed result."""
    # Case with groups present.
    # NOTE(review): the original comment describes this case as using double spaces, but
    # the separators appear as single spaces here -- confirm whitespace was not collapsed.
    grouped_input = "/Time_and_Position/time /Time_and_Position/instrument_fov_latitude /Time_and_Position/instrument_fov_longitude"
    grouped_expected = "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude"
    assert _flatten_coordinate_attribute(grouped_input) == grouped_expected

    # Case with NO groups present and single spaces.
    ungrouped_input = "/time /longitude /latitude /ozone_profile_pressure /ozone_profile_altitude"
    ungrouped_expected = (
        "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude"
    )
    assert _flatten_coordinate_attribute(ungrouped_input) == ungrouped_expected


def test_coordinate_attribute_regrouping():
    """Flattened coordinate attributes are converted back to "/"-delimited paths."""
    # Case with groups present.
    # NOTE(review): the original comment describes this case as using double spaces, but
    # the separators appear as single spaces here -- confirm whitespace was not collapsed.
    flattened_grouped = "__Time_and_Position__time __Time_and_Position__instrument_fov_latitude __Time_and_Position__instrument_fov_longitude"
    regrouped_expected = "Time_and_Position/time Time_and_Position/instrument_fov_latitude Time_and_Position/instrument_fov_longitude"
    assert regroup_coordinate_attribute(flattened_grouped) == regrouped_expected

    # Case with NO groups present and single spaces.
    flattened_ungrouped = (
        "__time __longitude __latitude __ozone_profile_pressure __ozone_profile_altitude"
    )
    ungrouped_expected = "time longitude latitude ozone_profile_pressure ozone_profile_altitude"
    assert regroup_coordinate_attribute(flattened_ungrouped) == ungrouped_expected