diff --git a/permacache/__init__.py b/permacache/__init__.py index b39da75..442aab4 100644 --- a/permacache/__init__.py +++ b/permacache/__init__.py @@ -2,4 +2,4 @@ from .cache_miss_error import CacheMissError, error_on_miss_global from .dict_function import drop_if, drop_if_equal from .hash import stable_hash, stringify -from .swap_unpickler import swap_unpickler_context_manager +from .swap_unpickler import renamed_symbol_unpickler, swap_unpickler_context_manager diff --git a/permacache/swap_unpickler.py b/permacache/swap_unpickler.py index bab775d..c3e3ee9 100644 --- a/permacache/swap_unpickler.py +++ b/permacache/swap_unpickler.py @@ -1,5 +1,6 @@ import pickle import shelve +from typing import Dict, Tuple, Union class swap_unpickler_context_manager: @@ -18,3 +19,39 @@ def __exit__(self, typ, value, traceback): self._previous_unpickler = None else: shelve.Unpickler = pickle.Unpickler + + +def renamed_symbol_unpickler( + symbol_rename_map: Dict[Tuple[str, str], Union[Tuple[str, str], type]] +) -> type: + """ + Returns an unpickler class that renames symbols as specified in + the symbol_rename_map dictionaries. + + :param symbol_rename_map: A dictionary mapping (module, name) pairs to + (new_module, new_name) pairs. Can also map to a type, in which case + we convert the type to a (module, name) pair. + """ + + symbol_rename_map_string = {} + for (module, name), new_symbol in symbol_rename_map.items(): + if isinstance(new_symbol, type): + new_symbol = (new_symbol.__module__, new_symbol.__name__) + assert ( + isinstance(new_symbol, tuple) + and len(new_symbol) == 2 + and all(isinstance(x, str) for x in new_symbol) + ), f"Invalid new symbol: {new_symbol}" + symbol_rename_map_string[(module, name)] = new_symbol + + class RenamedSymbolUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if (module, name) in symbol_rename_map_string: + module, name = symbol_rename_map_string[(module, name)] + try: + return super().find_class(module, name) + except: + print("Could not find", (module, name)) + raise + + return RenamedSymbolUnpickler diff --git a/setup.py b/setup.py index 79df7ee..1ccd008 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="permacache", - version="3.7.3", + version="3.8.0", author="Kavi Gupta", author_email="permacache@kavigupta.org", description="Permanant cache.", diff --git a/tests/test_cache.py b/tests/test_cache.py index 68c2cab..da9e9bb 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -2,6 +2,14 @@ import unittest from permacache import cache +from permacache.swap_unpickler import ( + renamed_symbol_unpickler, + swap_unpickler_context_manager, +) + +from .test_module.a import X as a_X +from .test_module.a import Y as a_Y +from .test_module.b import X as b_X # pylint: disable=keyword-arg-before-vararg @@ -35,3 +43,80 @@ def test_basic(self): self.assertEqual(fn.counter, 2) self.assertEqual(self.f(3, 2, 1, 0, -1), (3, 2, 1, (0, -1))) self.assertEqual(fn.counter, 3) + + +def g(x): + g.counter += 1 + return a_X(x) + + +class PermacacheRemappingTest(unittest.TestCase): + def setUp(self): + # we clean this up in tearDown + # pylint: disable=consider-using-with + self.dir = tempfile.TemporaryDirectory() + cache.CACHE = self.dir.name + g.counter = 0 + self.f = cache.permacache("g")(g) + self.assertIsInstance(self.f(1), a_X) + self.assertEqual(g.counter, 1) + + def tearDown(self): + self.dir.__exit__(None, None, None) + + def test_swap_to_b(self): + # populate cache + # swap to b + self.f = cache.permacache( + "g", + read_from_shelf_context_manager=swap_unpickler_context_manager( + renamed_symbol_unpickler( + { + ( + "tests.test_module.a", + "X", + ): b_X + } + ) + ), + )(g) + self.assertIsInstance(self.f(1), b_X) + self.assertEqual(g.counter, 1) + + def test_swap_to_y(self): + # populate cache + # swap to y + self.f = cache.permacache( + "g", + read_from_shelf_context_manager=swap_unpickler_context_manager( + renamed_symbol_unpickler( + { + ( + "tests.test_module.a", + "X", + ): a_Y + } + ) + ), + )(g) + self.assertIsInstance(self.f(1), a_Y) + self.assertEqual(g.counter, 1) + + def test_use_name(self): + # populate cache + # swap to y + self.f = cache.permacache( + "g", + read_from_shelf_context_manager=swap_unpickler_context_manager( + renamed_symbol_unpickler( + { + ( + "tests.test_module.a", + "X", + ): ("tests.test_module.a", "Y") + } + ) + ), + )(g) + self.assertIsInstance(self.f(1), a_Y) + self.assertEqual(g.counter, 1) diff --git a/tests/test_module/__init__.py b/tests/test_module/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_module/a.py b/tests/test_module/a.py new file mode 100644 index 0000000..519e0c6 --- /dev/null +++ b/tests/test_module/a.py @@ -0,0 +1,8 @@ +class X: + def __init__(self, x): + self.x = x + + +class Y: + def __init__(self, x): + self.x = x diff --git a/tests/test_module/b.py b/tests/test_module/b.py new file mode 100644 index 0000000..b12fc18 --- /dev/null +++ b/tests/test_module/b.py @@ -0,0 +1,3 @@ +class X: + def __init__(self, x): + self.x = x