Skip to content

Commit

Permalink
Built in pickle remapper (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
kavigupta authored Oct 15, 2024
1 parent 36c6768 commit 08a1eb4
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 2 deletions.
2 changes: 1 addition & 1 deletion permacache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .cache_miss_error import CacheMissError, error_on_miss_global
from .dict_function import drop_if, drop_if_equal
from .hash import stable_hash, stringify
from .swap_unpickler import swap_unpickler_context_manager
from .swap_unpickler import renamed_symbol_unpickler, swap_unpickler_context_manager
37 changes: 37 additions & 0 deletions permacache/swap_unpickler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import shelve
from typing import Dict, Tuple, Union


class swap_unpickler_context_manager:
Expand All @@ -18,3 +19,39 @@ def __exit__(self, typ, value, traceback):
self._previous_unpickler = None
else:
shelve.Unpickler = pickle.Unpickler


def renamed_symbol_unpickler(
symbol_rename_map: Dict[Tuple[str, str], Union[Tuple[str, str], type]]
) -> type:
"""
Returns an unpickler class that renames symbols as specified in
the symbol_rename_map dictionaries.
:param symbol_rename_map: A dictionary mapping (module, name) pairs to
(new_module, new_name) pairs. Can also map to a type, in which case
we convert the type to a (module, name) pair.
"""

symbol_rename_map_string = {}
for (module, name), new_symbol in symbol_rename_map.items():
if isinstance(new_symbol, type):
new_symbol = (new_symbol.__module__, new_symbol.__name__)
assert (
isinstance(new_symbol, tuple)
and len(new_symbol) == 2
and all(isinstance(x, str) for x in new_symbol)
), f"Invalid new symbol: {new_symbol}"
symbol_rename_map_string[(module, name)] = new_symbol

class RenamedSymbolUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if (module, name) in symbol_rename_map_string:
module, name = symbol_rename_map_string[(module, name)]
try:
return super().find_class(module, name)
except:
print("Could not find", (module, name))
raise

return RenamedSymbolUnpickler
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="permacache",
version="3.7.3",
version="3.8.0",
author="Kavi Gupta",
author_email="[email protected]",
description="Permanant cache.",
Expand Down
85 changes: 85 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
import unittest

from permacache import cache
from permacache.swap_unpickler import (
renamed_symbol_unpickler,
swap_unpickler_context_manager,
)

from .test_module.a import X as a_X
from .test_module.a import Y as a_Y
from .test_module.b import X as b_X


# pylint: disable=keyword-arg-before-vararg
Expand Down Expand Up @@ -35,3 +43,80 @@ def test_basic(self):
self.assertEqual(fn.counter, 2)
self.assertEqual(self.f(3, 2, 1, 0, -1), (3, 2, 1, (0, -1)))
self.assertEqual(fn.counter, 3)


def g(x):
g.counter += 1
return a_X(x)


class PermacacheRemappingTest(unittest.TestCase):
def setUp(self):
# we clean this up in tearDown
# pylint: disable=consider-using-with
self.dir = tempfile.TemporaryDirectory()
cache.CACHE = self.dir.name
g.counter = 0
self.f = cache.permacache("g")(g)
self.assertIsInstance(self.f(1), a_X)
self.assertEqual(g.counter, 1)

def tearDown(self):
self.dir.__exit__(None, None, None)

def test_swap_to_b(self):
# populate cache
# swap to b
self.f = cache.permacache(
"g",
read_from_shelf_context_manager=swap_unpickler_context_manager(
renamed_symbol_unpickler(
{
(
"tests.test_module.a",
"X",
): b_X
}
)
),
)(g)
self.assertIsInstance(self.f(1), b_X)
self.assertEqual(g.counter, 1)

def test_swap_to_y(self):
# populate cache
# swap to y
self.f = cache.permacache(
"g",
read_from_shelf_context_manager=swap_unpickler_context_manager(
renamed_symbol_unpickler(
{
(
"tests.test_module.a",
"X",
): a_Y
}
)
),
)(g)
self.assertIsInstance(self.f(1), a_Y)
self.assertEqual(g.counter, 1)

def test_use_name(self):
# populate cache
# swap to y
self.f = cache.permacache(
"g",
read_from_shelf_context_manager=swap_unpickler_context_manager(
renamed_symbol_unpickler(
{
(
"tests.test_module.a",
"X",
): ("tests.test_module.a", "Y")
}
)
),
)(g)
self.assertIsInstance(self.f(1), a_Y)
self.assertEqual(g.counter, 1)
Empty file added tests/test_module/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions tests/test_module/a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class X:
def __init__(self, x):
self.x = x


class Y:
def __init__(self, x):
self.x = x
3 changes: 3 additions & 0 deletions tests/test_module/b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class X:
def __init__(self, x):
self.x = x

0 comments on commit 08a1eb4

Please sign in to comment.