Skip to content

Commit

Permalink
fix: ChecksumAddressSingletonMeta race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler committed Aug 8, 2022
1 parent 51da1f6 commit 2c00359
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions checksum_dict/singleton.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,43 @@
import threading
from collections import defaultdict
from typing import Any, DefaultDict, Generic
from typing import Any, DefaultDict, Dict, Generic, Tuple

from checksum_dict.base import AnyAddressOrContract, ChecksumAddressDict, T


_LocksDict = DefaultDict[AnyAddressOrContract, threading.Lock]

class ChecksumAddressSingletonMeta(type, Generic[T]):
__locks: _LocksDict = defaultdict(threading.Lock)
__locks_lock: threading.Lock = threading.Lock()
__instances: ChecksumAddressDict[T] = ChecksumAddressDict()
def __init__(cls, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]) -> None:
super().__init__(name, bases, namespace)
cls.__instances: ChecksumAddressDict[T] = ChecksumAddressDict()
cls.__locks: _LocksDict = defaultdict(threading.Lock)
cls.__locks_lock: threading.Lock = threading.Lock()

def __call__(self, address: AnyAddressOrContract, *args: Any, **kwargs: Any) -> T: # type: ignore
def __call__(cls, address: AnyAddressOrContract, *args: Any, **kwargs: Any) -> T: # type: ignore
address = str(address)
try:
instance = self.__instances[address]
instance = cls.__instances[address]
except KeyError:
with self.__get_address_lock(address):
with cls.__get_address_lock(address):
# Try to get the instance again, in case it was added while waiting for the lock
try:
instance = self.__instances[address]
instance = cls.__instances[address]
except KeyError:
instance = super().__call__(address, *args, **kwargs)
self.__instances[address] = instance
self.__delete_address_lock(address)
cls.__instances[address] = instance
cls.__delete_address_lock(address)
return instance

def __get_address_lock(self, address: AnyAddressOrContract) -> threading.Lock:
def __get_address_lock(cls, address: AnyAddressOrContract) -> threading.Lock:
""" Makes sure the singleton is actually a singleton. """
with self.__locks_lock:
return self.__locks[address]
with cls.__locks_lock:
return cls.__locks[address]

def __delete_address_lock(self, address: AnyAddressOrContract) -> None:
def __delete_address_lock(cls, address: AnyAddressOrContract) -> None:
""" No need to maintain locks for initialized addresses. """
with self.__locks_lock:
with cls.__locks_lock:
try:
del self.__locks[address]
del cls.__locks[address]
except KeyError:
pass

0 comments on commit 2c00359

Please sign in to comment.