From c55c41ce20c4a5982dd403a66f42ab710520d70a Mon Sep 17 00:00:00 2001
From: Ian Langmore
Date: Wed, 12 Jun 2024 08:29:48 -0700
Subject: [PATCH] No public description

PiperOrigin-RevId: 642623358
---
 weatherbench2/evaluation.py |  1 +
 weatherbench2/metrics.py    |  2 ++
 weatherbench2/utils.py      | 48 ++++++++++++++++++++++++++++++++++++-
 weatherbench2/utils_test.py | 35 +++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py
index 9b4ab60..0718873 100644
--- a/weatherbench2/evaluation.py
+++ b/weatherbench2/evaluation.py
@@ -390,6 +390,7 @@ def _metric_and_region_loop(
 ) -> xr.Dataset:
   """Compute metric results looping over metrics and regions in eval config."""
   # Compute derived variables
+  logging.info('Starting _metric_and_region_loop')
   for name, dv in eval_config.derived_variables.items():
     logging.info(f'Logging: derived_variable {name!r}: {dv}')
     forecast[name] = dv.compute(forecast)
diff --git a/weatherbench2/metrics.py b/weatherbench2/metrics.py
index 98bbb4d..6b9ab57 100644
--- a/weatherbench2/metrics.py
+++ b/weatherbench2/metrics.py
@@ -24,6 +24,7 @@
 import numpy as np
 from scipy import stats
 from weatherbench2 import thresholds
+from weatherbench2 import utils
 from weatherbench2.regions import Region
 import xarray as xr
 
@@ -705,6 +706,7 @@ def compute_chunk(
     return _pointwise_crps_skill(forecast, truth, self.ensemble_dim)
 
 
+@utils.id_lru_cache(maxsize=1)
 def _pointwise_crps_spread(
     forecast: xr.Dataset, truth: xr.Dataset, ensemble_dim: str
 ) -> xr.Dataset:
diff --git a/weatherbench2/utils.py b/weatherbench2/utils.py
index 7d12a7f..d58ce8d 100644
--- a/weatherbench2/utils.py
+++ b/weatherbench2/utils.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """Utility function for WeatherBench2."""
-from typing import Callable, Union
+import functools
+from typing import Callable, Hashable, Union
 
 import fsspec
 import numpy as np
@@ -292,3 +293,48 @@ def random_like(dataset: xr.Dataset, seed: int = 0) -> xr.Dataset:
   return dataset.copy(
       data={k: rs.normal(size=v.shape) for k, v in dataset.items()}
   )
+
+
+def id_lru_cache(maxsize: int = 5):
+  """Like functools.lru_cache but uses argument id for non-hashables.
+
+  Warning: This is not threadsafe. Multiple threads reading/writing to the cache
+  results in inconsistent behavior.
+
+  Args:
+    maxsize: Maximum size of cache.
+
+  Returns:
+    Decorator to make a function into a caching function.
+  """
+
+  def hashid(x):
+    if isinstance(x, Hashable):
+      return hash(x)
+    return id(x)
+
+  def decorating_function(func):
+    cache = {}
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+      key = tuple(hashid(a) for a in args) + tuple(
+          (k, hashid(v)) for k, v in kwargs.items()
+      )
+      # Python dicts are ordered in the sense that if keys = list(my_dict.keys())
+      # then keys[0] is the first key added, and keys[-1] is the most recently
+      # added.
+      if key in cache:
+        # Move cache[key] to position -1 since it is most recently used.
+        value = cache[key]
+        del cache[key]
+        cache[key] = value
+      else:
+        if len(cache) >= maxsize:
+          cache.pop(list(cache)[0])  # Pop first item added to dictionary.
+        cache[key] = func(*args, **kwargs)
+      return cache[key]
+
+    return wrapper
+
+  return decorating_function
diff --git a/weatherbench2/utils_test.py b/weatherbench2/utils_test.py
index 4bb0d3c..b3992f9 100644
--- a/weatherbench2/utils_test.py
+++ b/weatherbench2/utils_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 from absl.testing import absltest
+import numpy as np
 from weatherbench2 import schema
 from weatherbench2 import utils
 import xarray
@@ -67,5 +68,39 @@ def testProbabilisticClimatology(self):
     self.assertEqual(clim['2m_temperature'].sizes, expected_sizes)
 
 
+class IdLRUCacheTest(absltest.TestCase):
+
+  def test_handles_non_hashable_args_and_kwargs(self):
+
+    @utils.id_lru_cache(maxsize=2)
+    def func(x: np.ndarray, y: np.ndarray, b: float = 1):
+      return np.sum(x + y * b)
+
+    # Use 3 sets of arrays so we are sure to cycle through the size 2 cache.
+    with self.subTest('First set of arrays'):
+      x = np.array([1.0, 2.0, 3.0])
+      y = x + 2
+      b = 1.3
+      expected = np.sum(x + y * b)
+      for _ in range(4):
+        self.assertEqual(expected, func(x, y, b=b))
+
+    with self.subTest('Second set of arrays'):
+      x = np.array([0.0, -2.0, 0.123])
+      y = np.array([10.0, -1.0, 3])
+      b = 10.3
+      expected = np.sum(x + y * b)
+      for _ in range(4):
+        self.assertEqual(expected, func(x, y, b=b))
+
+    with self.subTest('Third set of arrays'):
+      x = np.array([0.0, -20.0])
+      y = np.array([10.0, -11.0])
+      b = -1234
+      expected = np.sum(x + y * b)
+      for _ in range(4):
+        self.assertEqual(expected, func(x, y, b=b))
+
+
 if __name__ == '__main__':
   absltest.main()