Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-ndarray computations, cache unit calls, and add slots to DataArray #47

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sympl/_core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class DataArray(xr.DataArray):
__slots__ = []
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the implications of setting this? What warning is it suppressing, and what behavior does it cause when you set this to an empty list?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is aimed to suppress FutureWarning: xarray subclass DataArray should explicitly define __slots__. Here is a nice explanation of how __slots__ work.


def __add__(self, other):
"""If this DataArray is on the left side of the addition, keep its
Expand Down
9 changes: 6 additions & 3 deletions sympl/_core/get_np_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def get_numpy_array(data_array, out_dims, dim_lengths):
dict of dim_lengths that will give the length of any missing dims in the
data_array.
"""
if len(data_array.values.shape) == 0 and len(out_dims) == 0:
return data_array.values # special case, 0-dimensional scalar array
if len(data_array.data.shape) == 0 and len(out_dims) == 0:
Copy link
Owner

@mcgibbon mcgibbon Aug 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change alters the behavior of this function, which is OK, but the variable names, function name, file name, and docstring need to be updated. For example, I would suggest naming the function something like get_underlying_data.

The tests in test_get_restore_numpy_array.py should also be updated to cover cases where data is not a numpy array (and have similar re-namings).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. I will rename the function and update the tests.

return data_array.data # special case, 0-dimensional scalar array
else:
missing_dims = [dim for dim in out_dims if dim not in data_array.dims]
for dim in missing_dims:
data_array = data_array.expand_dims(dim)
numpy_array = data_array.transpose(*out_dims).values
if not all(dim1 == dim2 for dim1, dim2 in zip(data_array.dims, out_dims)):
numpy_array = data_array.transpose(*out_dims).data
else:
numpy_array = data_array.data
if len(missing_dims) == 0:
out_array = numpy_array
else: # expand out missing dims which are currently length 1.
Expand Down
7 changes: 4 additions & 3 deletions sympl/_core/restore_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@


def ensure_values_are_arrays(array_dict):
for name, value in array_dict.items():
if not isinstance(value, np.ndarray):
array_dict[name] = np.asarray(value)
pass
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a temporary and dirty solution. We could think of a mechanism to control whether arrays must be coerced or not.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't merge this change as-is, what problem is being solved here and what other solutions are available for it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The asarray function of Numpy seeks to coerce the input array-like storage value into a ndarray. This operation could break e.g. the data layout and the memory alignment of value. In the specific case of Tasmania, value could be a GT4Py storage, whose low-level details and features are fitted to the target computing architecture and thus must be preserved. The problem can be circumvented by monkey-patching Numpy via the function gt4py.storage.prepare_numpy(), but this is much GT4Py-specific. We could think of a more organic solution, or just pass value to DataArray as it is and let DataArray perform all type checks (and eventually throw exceptions).

# for name, value in array_dict.items():
# if not isinstance(value, np.ndarray):
# array_dict[name] = np.asarray(value)


def get_alias_or_name(name, output_properties, input_properties):
Expand Down
43 changes: 22 additions & 21 deletions sympl/_core/units.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
# -*- coding: utf-8 -*-
import functools
import pint


class UnitRegistry(pint.UnitRegistry):

@functools.lru_cache
def __call__(self, input_string, **kwargs):
return super(UnitRegistry, self).__call__(
input_string.replace(
u'%', 'percent').replace(
u'°', 'degree'
),
**kwargs)
input_string.replace(u"%", "percent").replace(u"°", "degree"), **kwargs
)


unit_registry = UnitRegistry()
unit_registry.define('degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN')
unit_registry.define('degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE')
unit_registry.define('percent = 0.01*count = %')
unit_registry.define(
"degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN"
)
unit_registry.define(
"degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE"
)
unit_registry.define("percent = 0.01*count = %")


def units_are_compatible(unit1, unit2):
Expand Down Expand Up @@ -63,9 +65,7 @@ def clean_units(unit_string):

def is_valid_unit(unit_string):
"""Returns True if the unit string is recognized, and False otherwise."""
unit_string = unit_string.replace(
'%', 'percent').replace(
'°', 'degree')
unit_string = unit_string.replace("%", "percent").replace("°", "degree")
try:
unit_registry(unit_string)
except pint.UndefinedUnitError:
Expand All @@ -75,16 +75,17 @@ def is_valid_unit(unit_string):


def data_array_to_units(value, units):
if not hasattr(value, 'attrs') or 'units' not in value.attrs:
raise TypeError(
'Cannot retrieve units from type {}'.format(type(value)))
elif unit_registry(value.attrs['units']) != unit_registry(units):
attrs = value.attrs.copy()
value = unit_registry.Quantity(value, value.attrs['units']).to(units).magnitude
attrs['units'] = units
value.attrs = attrs
if not hasattr(value, "attrs") or "units" not in value.attrs:
raise TypeError("Cannot retrieve units from type {}".format(type(value)))
elif unit_registry(value.attrs["units"]) != unit_registry(units):
out = value.copy()
out.data[...] = (
unit_registry.convert(1, value.attrs["units"], units) * value.data
)
out.attrs["units"] = units
value = out
return value


def from_unit_to_another(value, original_units, new_units):
return (unit_registry(original_units)*value).to(new_units).magnitude
return (unit_registry(original_units) * value).to(new_units).magnitude