Skip to content

Commit

Permalink
fix(flock): Make it write-write re-entrant as well
Browse files Browse the repository at this point in the history
  • Loading branch information
Rizhiy committed Jun 1, 2024
1 parent 435283b commit 558de35
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 47 deletions.
35 changes: 18 additions & 17 deletions replete/flock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ def __init__(self, file_path: Path):
self._file_path.parent.mkdir(parents=True, exist_ok=True)
self._file_path.touch(exist_ok=True)

self._locked = False
self._fd: TextIOWrapper | None = None
self._lock_code: int | None = None
self._write_locks: list[FileLock] = []
self._read_locks: list[FileLock] = []
self._dependent_locks: list[FileLock] = []

@property
def file_path(self) -> Path:
return self._file_path

@property
def locked(self) -> bool:
return self._fd is not None
return self._locked

def _get_locked_write_lock_or_none(self) -> FileLock | None:
if self._write_locks and any(lock.locked for lock in self._write_locks):
Expand All @@ -52,6 +53,7 @@ def write_lock(self, *, non_blocking=False) -> FileLock:
raise ValueError("Can't get read lock when lock is already acquired")
self_copy = copy.copy(self)
self_copy._lock_code = fcntl.LOCK_EX # noqa: SLF001
self_copy._write_locks = self._write_locks # noqa: SLF001
if non_blocking:
self_copy._lock_code |= fcntl.LOCK_NB # noqa: SLF001
self._write_locks.append(self_copy)
Expand All @@ -62,15 +64,15 @@ def acquire(self) -> None:
raise ValueError("Can't acquire bare lock, please use either read_lock or write_lock")

self._fd = self._file_path.open()
# If write lock already exists and locked, we can give read lock
if self._lock_code in {fcntl.LOCK_SH, fcntl.LOCK_SH | fcntl.LOCK_NB}:
write_lock = self._get_locked_write_lock_or_none()
if write_lock:
write_lock._read_locks.append(self) # noqa: SLF001
return
write_lock = self._get_locked_write_lock_or_none()
if write_lock:
write_lock._dependent_locks.append(self) # noqa: SLF001
self._locked = True
return

try:
fcntl.flock(self._fd, self._lock_code)
self._locked = True
except BlockingIOError:
# Release resources if we fail acquiring non-blocking lock
self._fd.close()
Expand All @@ -81,19 +83,18 @@ def release(self) -> None:
if not self._fd:
raise ValueError(f"Lock on {self._file_path} is not acquired and cannot be released")

if self._lock_code in {fcntl.LOCK_EX, fcntl.LOCK_EX | fcntl.LOCK_NB} and self._read_locks:
if self._lock_code in {fcntl.LOCK_EX, fcntl.LOCK_EX | fcntl.LOCK_NB} and self._dependent_locks:
raise ValueError(
"Found unreleased read lock, please release all read locks before releasing main write lock!",
"Found unreleased dependent lock, please release all read locks before releasing main write lock!",
)

# Check if read lock was acquired over write lock and remove ourselves from dependencies
if self._lock_code in {fcntl.LOCK_SH, fcntl.LOCK_SH | fcntl.LOCK_NB}:
write_lock = self._get_locked_write_lock_or_none()
if write_lock:
write_lock._read_locks.remove(self) # noqa: SLF001
self._fd.close()
self._fd = None
return

self._locked = False
write_lock = self._get_locked_write_lock_or_none()
if write_lock:
write_lock._dependent_locks.remove(self) # noqa: SLF001
return

fcntl.flock(self._fd, fcntl.LOCK_UN)
self._fd.close()
Expand Down
46 changes: 16 additions & 30 deletions tests/test_flock.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,53 +48,39 @@ def test_parallel(tmp_path, executor_cls):


def test_read_read_reentrant(tmp_path):
path = tmp_path / "tmp.txt"
with path.open("w") as f:
f.write("TEST\n")
lock_path = path.parent / f"{path.name}.lock"
lock_path = tmp_path / "file.lock"
lock = FileLock(lock_path)
with lock.read_lock(), lock.read_lock(), path.open("r") as f:
assert f.read().strip() == "TEST"
with lock.read_lock(), lock.read_lock(non_blocking=True):
pass


# If we have already acquired a write lock (exclusive), we should be able to get a read lock
def test_write_read_reentrant(tmp_path):
path = tmp_path / "tmp.txt"
with path.open("w") as f:
f.write("TEST\n")
lock_path = path.parent / f"{path.name}.lock"
lock_path = tmp_path / "file.lock"
lock = FileLock(lock_path)
with (
lock.write_lock(),
lock.read_lock(non_blocking=True),
path.open("r") as f,
):
assert f.read().strip() == "TEST"
with lock.write_lock(), lock.read_lock(non_blocking=True):
pass


def test_read_write_error(tmp_path):
path = tmp_path / "tmp.txt"
lock_path = path.parent / f"{path.name}.lock"
def test_write_write_reentrant(tmp_path):
lock_path = tmp_path / "file.lock"
lock = FileLock(lock_path)
with pytest.raises(BlockingIOError), lock.read_lock():
lock.write_lock(non_blocking=True).acquire()
with lock.write_lock(), lock.write_lock(non_blocking=True):
pass


def test_write_write_error(tmp_path):
path = tmp_path / "tmp.txt"
lock_path = path.parent / f"{path.name}.lock"
def test_read_write_error(tmp_path):
lock_path = tmp_path / "file.lock"
lock = FileLock(lock_path)
with pytest.raises(BlockingIOError), lock.write_lock():
with pytest.raises(BlockingIOError), lock.read_lock():
lock.write_lock(non_blocking=True).acquire()


def test_write_read_reentrant_wrong_release_order_error(tmp_path):
path = tmp_path / "tmp.txt"
lock_path = path.parent / f"{path.name}.lock"
def test_reentrant_wrong_release_order_error(tmp_path):
lock_path = tmp_path / "file.lock"
lock = FileLock(lock_path)
write_lock = lock.write_lock()
read_lock = lock.read_lock()
write_lock.acquire()
read_lock.acquire()
with pytest.raises(ValueError, match="unreleased read lock"):
with pytest.raises(ValueError, match="unreleased dependent lock"):
write_lock.release()

0 comments on commit 558de35

Please sign in to comment.