
Commit

Refactored dask utility function
MitchellAV committed May 16, 2024
1 parent 886f778 commit dc4cba5
Showing 1 changed file with 143 additions and 89 deletions.
232 changes: 143 additions & 89 deletions workers/src/utility.py
# workers/src/utility.py -- refactored section. The unchanged
# multiprocess() helper and the file's import block sit above this hunk;
# the imports below are the ones this section needs and are assumed to
# match the top of the file.
import logging
import math
import os
from logging import Logger
from typing import Any, Callable, Tuple, TypeVar

import psutil
from dask import config, delayed
from dask.distributed import Client

T = TypeVar("T")


def logger_if_able(
    message: str, logger: Logger | None = None, level: str = "INFO"
):
    if logger is not None:
        levels_dict = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
            "CRITICAL": logging.CRITICAL,
        }

        level = level.upper()

        if level not in levels_dict:
            raise Exception(f"Invalid log level: {level}")

        log_level = levels_dict[level]

        logger.log(log_level, message)
    else:
        print(message)
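
Call sites can pass a real Logger or nothing at all; a small sketch (the logger name is illustrative):

log = logging.getLogger("workers")
logger_if_able("starting run", log, "debug")  # level is case-insensitive, routes to log.debug
logger_if_able("starting run")  # no logger given, falls back to print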


MEMORY_PER_RUN = 7.0  # in GB


def set_workers_and_threads(
    cpu_count: int | None,
    sys_memory: float,
    memory_per_run: float | int,
    n_workers: int | None = None,
    threads_per_worker: int | None = None,
    logger: Logger | None = None,
) -> Tuple[int, int]:

    def handle_exceeded_resources(
        n_workers, threads_per_worker, memory_per_run, sys_memory
    ):
        if memory_per_run * n_workers * threads_per_worker > sys_memory:
            config.set({"distributed.worker.memory.spill": True})
            logger_if_able(
                f"Memory per worker exceeds system memory ({memory_per_run} GB), activating memory spill",
                logger,
                "WARNING",
            )

    total_workers: int = 1
    total_threads: int = 1

    if cpu_count is None:
        raise Exception("Could not determine number of CPUs.")
    if n_workers is not None and threads_per_worker is not None:
        if n_workers * threads_per_worker > cpu_count:
            raise Exception(
                f"workers and threads exceed local resources, {cpu_count} cores present"
            )
        handle_exceeded_resources(
            n_workers, threads_per_worker, memory_per_run, sys_memory
        )
        total_workers, total_threads = n_workers, threads_per_worker

    if n_workers is not None and threads_per_worker is None:
        threads_per_worker = int(
            math.floor(sys_memory / (memory_per_run * n_workers))
        )
        if threads_per_worker == 0:
            logger_if_able(
                "Not enough memory for a worker, defaulting to 1 thread per worker",
                logger,
                "WARNING",
            )
            threads_per_worker = 1
        handle_exceeded_resources(
            n_workers, threads_per_worker, memory_per_run, sys_memory
        )
        total_workers, total_threads = n_workers, threads_per_worker

    if n_workers is None and threads_per_worker is not None:
        n_workers = int(
            math.floor(sys_memory / (memory_per_run * threads_per_worker))
        )
        if n_workers == 0:
            logger_if_able(
                "Not enough memory for a worker, defaulting to 1 worker",
                logger,
                "WARNING",
            )
            n_workers = 1
        handle_exceeded_resources(
            n_workers, threads_per_worker, memory_per_run, sys_memory
        )
        total_workers, total_threads = n_workers, threads_per_worker

    if n_workers is None and threads_per_worker is None:
        if memory_per_run == 0:
            raise Exception("Memory limit cannot be 0")

        thread_worker_total = math.floor(sys_memory / memory_per_run)
        if thread_worker_total < 2:
            logger_if_able(
                "Not enough memory for a worker, defaulting to 1 worker and 1 thread per worker",
                logger,
                "WARNING",
            )
            n_workers = 1
            threads_per_worker = 1
            handle_exceeded_resources(
                n_workers, threads_per_worker, memory_per_run, sys_memory
            )
            total_workers, total_threads = n_workers, threads_per_worker
            return total_workers, total_threads
        else:
            logger_if_able(
                f"thread_worker_total: {thread_worker_total}",
                logger,
                "DEBUG",
            )
            n_workers = int(math.ceil(thread_worker_total / 2))
            threads_per_worker = int(math.floor(thread_worker_total / 2))
            if n_workers + threads_per_worker != thread_worker_total:
                logger_if_able(
                    f"n_workers: {n_workers}, threads_per_worker: {threads_per_worker}, thread_worker_total: {thread_worker_total}",
                    logger,
                    "INFO",
                )
                logger_if_able(
                    "Could not determine number of workers and threads",
                    logger,
                    "ERROR",
                )
                raise Exception(
                    "Could not determine number of workers and threads"
                )
            handle_exceeded_resources(
                n_workers, threads_per_worker, memory_per_run, sys_memory
            )
            total_workers, total_threads = n_workers, threads_per_worker

    return total_workers, total_threads
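
For a concrete feel (a worked sketch; the 16-core, 32 GB host is hypothetical): with neither count supplied and the 7 GB default per run, the budget is floor(32 / 7) = 4 slots, which the function splits into 2 workers of 2 threads each.

workers, threads = set_workers_and_threads(
    cpu_count=16,
    sys_memory=32.0,
    memory_per_run=7.0,
)
assert (workers, threads) == (2, 2)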


def dask_multiprocess(
    func: Callable[..., T],
    func_arguments: list[tuple[Any, ...]],
    n_workers: int | None = None,
    threads_per_worker: int | None = None,
    memory_per_run: float | int | None = None,
    logger: Logger | None = None,
    **kwargs,
):
    memory_per_run = memory_per_run or MEMORY_PER_RUN

    cpu_count = os.cpu_count()

    sys_memory = psutil.virtual_memory().total / (1024.0**3)  # in GB

    # Worker-memory settings: spill stays off here (set_workers_and_threads
    # may enable it), workers pause under memory pressure, the memory target
    # is raised to 95%, and workers are never terminated for exceeding it.
    # config.set({"distributed.worker.memory.spill": True})
    config.set({"distributed.worker.memory.pause": True})
    config.set({"distributed.worker.memory.target": 0.95})
    config.set({"distributed.worker.memory.terminate": False})

    total_workers, total_threads = set_workers_and_threads(
        cpu_count,
        sys_memory,
        memory_per_run,
        n_workers,
        threads_per_worker,
        logger,
    )

    memory_per_worker = memory_per_run * total_threads

    logger_if_able(f"cpu count: {cpu_count}", logger, "INFO")
    logger_if_able(f"memory: {sys_memory}", logger, "INFO")
    logger_if_able(f"memory per run: {memory_per_run}", logger, "INFO")
    logger_if_able(f"n_workers: {total_workers}", logger, "INFO")
    logger_if_able(f"threads_per_worker: {total_threads}", logger, "INFO")
    logger_if_able(f"memory per worker: {memory_per_worker}", logger, "INFO")

    results = []

    with Client(
        n_workers=total_workers,
        threads_per_worker=total_threads,
        memory_limit=f"{memory_per_worker}GiB",
        **kwargs,
    ) as client:

        logger_if_able(f"client: {client}", logger, "INFO")

        lazy_results = []
        for args in func_arguments:
            lazy_result = delayed(func, pure=True)(*args)
            lazy_results.append(lazy_result)

        futures = client.compute(lazy_results)

        results = client.gather(futures)

    return results
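
A minimal usage sketch (square, the argument list, and the 0.5 GB per-run figure are illustrative; the mapped function is defined at module level here for safe pickling):

def square(x: int) -> int:
    return x * x


if __name__ == "__main__":
    results = dask_multiprocess(
        square,
        [(1,), (2,), (3,), (4,)],
        memory_per_run=0.5,
    )
    print(results)  # [1, 4, 9, 16]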
