From 558de35e543bf1aa93a35ecfba7558fde5d1c3f5 Mon Sep 17 00:00:00 2001 From: Artem Vasenin Date: Sat, 1 Jun 2024 01:04:33 +0100 Subject: [PATCH] fix(flock): Make it write-write re-entrant as well --- replete/flock.py | 35 +++++++++++++++++----------------- tests/test_flock.py | 46 ++++++++++++++++----------------------------- 2 files changed, 34 insertions(+), 47 deletions(-) diff --git a/replete/flock.py b/replete/flock.py index 8f4ae41..db6eda8 100644 --- a/replete/flock.py +++ b/replete/flock.py @@ -19,10 +19,11 @@ def __init__(self, file_path: Path): self._file_path.parent.mkdir(parents=True, exist_ok=True) self._file_path.touch(exist_ok=True) + self._locked = False self._fd: TextIOWrapper | None = None self._lock_code: int | None = None self._write_locks: list[FileLock] = [] - self._read_locks: list[FileLock] = [] + self._dependent_locks: list[FileLock] = [] @property def file_path(self) -> Path: @@ -30,7 +31,7 @@ def file_path(self) -> Path: @property def locked(self) -> bool: - return self._fd is not None + return self._locked def _get_locked_write_lock_or_none(self) -> FileLock | None: if self._write_locks and any(lock.locked for lock in self._write_locks): @@ -52,6 +53,7 @@ def write_lock(self, *, non_blocking=False) -> FileLock: raise ValueError("Can't get read lock when lock is already acquired") self_copy = copy.copy(self) self_copy._lock_code = fcntl.LOCK_EX # noqa: SLF001 + self_copy._write_locks = self._write_locks # noqa: SLF001 if non_blocking: self_copy._lock_code |= fcntl.LOCK_NB # noqa: SLF001 self._write_locks.append(self_copy) @@ -62,15 +64,15 @@ def acquire(self) -> None: raise ValueError("Can't acquire bare lock, please use either read_lock or write_lock") self._fd = self._file_path.open() - # If write lock already exists and locked, we can give read lock - if self._lock_code in {fcntl.LOCK_SH, fcntl.LOCK_SH | fcntl.LOCK_NB}: - write_lock = self._get_locked_write_lock_or_none() - if write_lock: - write_lock._read_locks.append(self) # noqa: SLF001 - return + write_lock = self._get_locked_write_lock_or_none() + if write_lock: + write_lock._dependent_locks.append(self) # noqa: SLF001 + self._locked = True + return try: fcntl.flock(self._fd, self._lock_code) + self._locked = True except BlockingIOError: # Release resources if we fail acquiring non-blocking lock self._fd.close() @@ -81,19 +83,18 @@ def release(self) -> None: if not self._fd: raise ValueError(f"Lock on {self._file_path} is not acquired and cannot be released") - if self._lock_code in {fcntl.LOCK_EX, fcntl.LOCK_EX | fcntl.LOCK_NB} and self._read_locks: + if self._lock_code in {fcntl.LOCK_EX, fcntl.LOCK_EX | fcntl.LOCK_NB} and self._dependent_locks: raise ValueError( - "Found unreleased read lock, please release all read locks before releasing main write lock!", + "Found unreleased dependent lock, please release all read locks before releasing main write lock!", ) # Check if read lock was acquired over write lock and remove ourselves from dependencies - if self._lock_code in {fcntl.LOCK_SH, fcntl.LOCK_SH | fcntl.LOCK_NB}: - write_lock = self._get_locked_write_lock_or_none() - if write_lock: - write_lock._read_locks.remove(self) # noqa: SLF001 - self._fd.close() - self._fd = None - return + + self._locked = False + write_lock = self._get_locked_write_lock_or_none() + if write_lock: + write_lock._dependent_locks.remove(self) # noqa: SLF001 + return fcntl.flock(self._fd, fcntl.LOCK_UN) self._fd.close() diff --git a/tests/test_flock.py b/tests/test_flock.py index 855435c..ca62636 100644 --- a/tests/test_flock.py +++ b/tests/test_flock.py @@ -48,53 +48,39 @@ def test_parallel(tmp_path, executor_cls): def test_read_read_reentrant(tmp_path): - path = tmp_path / "tmp.txt" - with path.open("w") as f: - f.write("TEST\n") - lock_path = path.parent / f"{path.name}.lock" + lock_path = tmp_path / "file.lock" lock = FileLock(lock_path) - with lock.read_lock(), lock.read_lock(), path.open("r") as f: - assert f.read().strip() == "TEST" + with lock.read_lock(), lock.read_lock(non_blocking=True): + pass -# If we have already acquired a write lock (exclusive), we should be able to get a read lock def test_write_read_reentrant(tmp_path): - path = tmp_path / "tmp.txt" - with path.open("w") as f: - f.write("TEST\n") - lock_path = path.parent / f"{path.name}.lock" + lock_path = tmp_path / "file.lock" lock = FileLock(lock_path) - with ( - lock.write_lock(), - lock.read_lock(non_blocking=True), - path.open("r") as f, - ): - assert f.read().strip() == "TEST" + with lock.write_lock(), lock.read_lock(non_blocking=True): + pass -def test_read_write_error(tmp_path): - path = tmp_path / "tmp.txt" - lock_path = path.parent / f"{path.name}.lock" +def test_write_write_reentrant(tmp_path): + lock_path = tmp_path / "file.lock" lock = FileLock(lock_path) - with pytest.raises(BlockingIOError), lock.read_lock(): - lock.write_lock(non_blocking=True).acquire() + with lock.write_lock(), lock.write_lock(non_blocking=True): + pass -def test_write_write_error(tmp_path): - path = tmp_path / "tmp.txt" - lock_path = path.parent / f"{path.name}.lock" +def test_read_write_error(tmp_path): + lock_path = tmp_path / "file.lock" lock = FileLock(lock_path) - with pytest.raises(BlockingIOError), lock.write_lock(): + with pytest.raises(BlockingIOError), lock.read_lock(): lock.write_lock(non_blocking=True).acquire() -def test_write_read_reentrant_wrong_release_order_error(tmp_path): - path = tmp_path / "tmp.txt" - lock_path = path.parent / f"{path.name}.lock" +def test_reentrant_wrong_release_order_error(tmp_path): + lock_path = tmp_path / "file.lock" lock = FileLock(lock_path) write_lock = lock.write_lock() read_lock = lock.read_lock() write_lock.acquire() read_lock.acquire() - with pytest.raises(ValueError, match="unreleased read lock"): + with pytest.raises(ValueError, match="unreleased dependent lock"): write_lock.release()