From 9d8ec5be8948e2e2da53a3de0572bdaef14e81de Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Mon, 28 Oct 2019 09:52:32 +0100 Subject: [PATCH 1/6] Transpose fields only when really needed. --- sympl/_core/get_np_arrays.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sympl/_core/get_np_arrays.py b/sympl/_core/get_np_arrays.py index 9a84b57..2668b75 100644 --- a/sympl/_core/get_np_arrays.py +++ b/sympl/_core/get_np_arrays.py @@ -49,7 +49,10 @@ def get_numpy_array(data_array, out_dims, dim_lengths): missing_dims = [dim for dim in out_dims if dim not in data_array.dims] for dim in missing_dims: data_array = data_array.expand_dims(dim) - numpy_array = data_array.transpose(*out_dims).values + if not all(dim1 == dim2 for dim1, dim2 in zip(data_array.dims, out_dims)): + numpy_array = data_array.transpose(*out_dims).values + else: + numpy_array = data_array.values if len(missing_dims) == 0: out_array = numpy_array else: # expand out missing dims which are currently length 1. From 12327bb0d6c0d548853cc8541f7c43819173a7d6 Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Tue, 17 Dec 2019 12:25:29 +0100 Subject: [PATCH 2/6] Added __slots__ to DataArray class. --- sympl/_core/dataarray.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sympl/_core/dataarray.py b/sympl/_core/dataarray.py index b6b3f84..933786b 100644 --- a/sympl/_core/dataarray.py +++ b/sympl/_core/dataarray.py @@ -4,6 +4,7 @@ class DataArray(xr.DataArray): + __slots__ = [] def __add__(self, other): """If this DataArray is on the left side of the addition, keep its From be38d6572175cee0950c7bb8a97b4f5f6975fd6a Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Fri, 1 May 2020 15:46:02 +0200 Subject: [PATCH 3/6] Use ``data`` rather than ``values`` property of DataArrays. --- sympl/_core/get_np_arrays.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sympl/_core/get_np_arrays.py b/sympl/_core/get_np_arrays.py index 2668b75..1cb98d4 100644 --- a/sympl/_core/get_np_arrays.py +++ b/sympl/_core/get_np_arrays.py @@ -43,16 +43,16 @@ def get_numpy_array(data_array, out_dims, dim_lengths): dict of dim_lengths that will give the length of any missing dims in the data_array. """ - if len(data_array.values.shape) == 0 and len(out_dims) == 0: - return data_array.values # special case, 0-dimensional scalar array + if len(data_array.data.shape) == 0 and len(out_dims) == 0: + return data_array.data # special case, 0-dimensional scalar array else: missing_dims = [dim for dim in out_dims if dim not in data_array.dims] for dim in missing_dims: data_array = data_array.expand_dims(dim) if not all(dim1 == dim2 for dim1, dim2 in zip(data_array.dims, out_dims)): - numpy_array = data_array.transpose(*out_dims).values + numpy_array = data_array.transpose(*out_dims).data else: - numpy_array = data_array.values + numpy_array = data_array.data if len(missing_dims) == 0: out_array = numpy_array else: # expand out missing dims which are currently length 1. From 56620eab8c5ebd0d18984c3f928b5eb304030053 Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Fri, 1 May 2020 15:47:16 +0200 Subject: [PATCH 4/6] Do not coerce raw storages to np.ndarrays. --- sympl/_core/restore_dataarray.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sympl/_core/restore_dataarray.py b/sympl/_core/restore_dataarray.py index 791fba1..cf87e90 100644 --- a/sympl/_core/restore_dataarray.py +++ b/sympl/_core/restore_dataarray.py @@ -8,9 +8,10 @@ def ensure_values_are_arrays(array_dict): - for name, value in array_dict.items(): - if not isinstance(value, np.ndarray): - array_dict[name] = np.asarray(value) + pass + # for name, value in array_dict.items(): + # if not isinstance(value, np.ndarray): + # array_dict[name] = np.asarray(value) def get_alias_or_name(name, output_properties, input_properties): From 9ddcebdd5549b0fe0458b7550a74b15bdb52999e Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Thu, 15 Oct 2020 08:45:26 +0200 Subject: [PATCH 5/6] Avoid relying on unit_registry to check if units are the same. --- sympl/_core/units.py | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/sympl/_core/units.py b/sympl/_core/units.py index d2977e8..5451816 100644 --- a/sympl/_core/units.py +++ b/sympl/_core/units.py @@ -3,20 +3,20 @@ class UnitRegistry(pint.UnitRegistry): - def __call__(self, input_string, **kwargs): return super(UnitRegistry, self).__call__( - input_string.replace( - u'%', 'percent').replace( - u'°', 'degree' - ), - **kwargs) + input_string.replace(u"%", "percent").replace(u"°", "degree"), **kwargs + ) unit_registry = UnitRegistry() -unit_registry.define('degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN') -unit_registry.define('degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE') -unit_registry.define('percent = 0.01*count = %') +unit_registry.define( + "degrees_north = degree_north = degree_N = degrees_N = degreeN = degreesN" +) +unit_registry.define( + "degrees_east = degree_east = degree_E = degrees_E = degreeE = degreesE" +) +unit_registry.define("percent = 0.01*count = %") def units_are_compatible(unit1, unit2): @@ -63,9 +63,7 @@ def clean_units(unit_string): def is_valid_unit(unit_string): """Returns True if the unit string is recognized, and False otherwise.""" - unit_string = unit_string.replace( - '%', 'percent').replace( - '°', 'degree') + unit_string = unit_string.replace("%", "percent").replace("°", "degree") try: unit_registry(unit_string) except pint.UndefinedUnitError: @@ -75,16 +73,18 @@ def is_valid_unit(unit_string): def data_array_to_units(value, units): - if not hasattr(value, 'attrs') or 'units' not in value.attrs: - raise TypeError( - 'Cannot retrieve units from type {}'.format(type(value))) - elif unit_registry(value.attrs['units']) != unit_registry(units): - attrs = value.attrs.copy() - value = unit_registry.Quantity(value, value.attrs['units']).to(units).magnitude - attrs['units'] = units - value.attrs = attrs + if not hasattr(value, "attrs") or "units" not in value.attrs: + raise TypeError("Cannot retrieve units from type {}".format(type(value))) + elif value.attrs["units"] != units: + # elif unit_registry(value.attrs["units"]) != unit_registry(units): + out = value.copy() + out.data[...] = ( + unit_registry.convert(1, value.attrs["units"], units) * value.data + ) + out.attrs["units"] = units + value = out return value def from_unit_to_another(value, original_units, new_units): - return (unit_registry(original_units)*value).to(new_units).magnitude + return (unit_registry(original_units) * value).to(new_units).magnitude From 29b5f23638cd390e6c3d26722597c38a41d4920e Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Tue, 12 Jan 2021 15:37:37 +0100 Subject: [PATCH 6/6] Cache calls to UnitRegistry. --- sympl/_core/units.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sympl/_core/units.py b/sympl/_core/units.py index 5451816..2278ce0 100644 --- a/sympl/_core/units.py +++ b/sympl/_core/units.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- +import functools import pint class UnitRegistry(pint.UnitRegistry): + @functools.lru_cache def __call__(self, input_string, **kwargs): return super(UnitRegistry, self).__call__( input_string.replace(u"%", "percent").replace(u"°", "degree"), **kwargs @@ -75,8 +77,7 @@ def is_valid_unit(unit_string): def data_array_to_units(value, units): if not hasattr(value, "attrs") or "units" not in value.attrs: raise TypeError("Cannot retrieve units from type {}".format(type(value))) - elif value.attrs["units"] != units: - # elif unit_registry(value.attrs["units"]) != unit_registry(units): + elif unit_registry(value.attrs["units"]) != unit_registry(units): out = value.copy() out.data[...] = ( unit_registry.convert(1, value.attrs["units"], units) * value.data