diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 70af370..b2598b6 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -129,6 +129,20 @@ def rwlock_wrap( polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"] ... >>> assert not out_fp.is_file() # Out file should not be created when the process crashes + + If the lock file already exists, the function will not do anything + >>> def compute_fn(df: pl.DataFrame) -> pl.DataFrame: + ... return df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4) + >>> out_fp = root / "output.csv" + >>> lock_fp = root / "output.csv.lock" + >>> with FileLock(str(lock_fp)): + ... result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) + ... assert not result_computed + + The lock file will be removed after successful processing. + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) + >>> assert result_computed + >>> assert not lock_fp.exists() """ if out_fp_checker(out_fp): @@ -139,8 +153,8 @@ def rwlock_wrap( logger.info(f"{out_fp} exists; returning.") return False - lock_fp = str(out_fp) + ".lock" - lock = FileLock(lock_fp) + lock_fp = out_fp.with_suffix(f"{out_fp.suffix}.lock") + lock = FileLock(str(lock_fp)) try: lock.acquire(timeout=0) except Timeout: @@ -159,6 +173,7 @@ def rwlock_wrap( return True finally: lock.release() + lock_fp.unlink() def shuffle_shards(shards: list[str], cfg: DictConfig) -> list[str]: