Reworks thread execution model #1

Draft · wants to merge 2 commits into base: develop
2 changes: 1 addition & 1 deletion nextpipe/__about__.py

@@ -1 +1 @@
-__version__ = "v0.1.0.dev4"
+__version__ = "v0.1.0.dev5"
97 changes: 46 additions & 51 deletions nextpipe/flow.py

@@ -8,9 +8,8 @@
 from typing import List, Optional, Union

 from nextmv.cloud import Application, Client, StatusV2
-from pathos.multiprocessing import ProcessingPool as Pool

-from . import decorators, schema, utils
+from . import decorators, schema, threads, utils


 class DAGNode:
@@ -80,55 +79,51 @@ def run(self):

         # Run the nodes in parallel
         tasks = {}
-        with Pool(8) as pool:
-            while open_nodes:
-                while True:
-                    # Get the first node from the open nodes which has all its predecessors done
-                    node = next(
-                        iter(
-                            filter(
-                                lambda n: all(p in closed_nodes for p in n.predecessors),
-                                open_nodes,
-                            )
-                        ),
-                        None,
-                    )
-                    if node is None:
-                        # No more nodes to run at this point. Wait for the remaining tasks to finish.
-                        break
-                    open_nodes.remove(node)
-                    # Skip the node if it is optional and the condition is not met
-                    if node.step.skip():
-                        utils.log(f"Skipping node {node.step.get_name()}")
-                        node.step.set_state("skipped")
-                        utils.log("NEXTPIPE_DAG_UPDATE=" + self.graph._persist_dag_update(node))
-                        continue
-                    # Run the node asynchronously
-                    tasks[node] = pool.apipe(
-                        self.__run_node,
-                        node,
-                        self._get_inputs(node),
-                        self.client,
-                    )
-                    node.step.set_state("running")
+        pool = threads.Pool(8)
+        while open_nodes:
+            while True:
+                # Get the first node from the open nodes which has all its predecessors done
+                node = next(
+                    iter(
+                        filter(
+                            lambda n: all(p in closed_nodes for p in n.predecessors),
+                            open_nodes,
+                        )
+                    ),
+                    None,
+                )
+                if node is None:
+                    # No more nodes to run at this point. Wait for the remaining tasks to finish.
+                    break
+                open_nodes.remove(node)
+                # Skip the node if it is optional and the condition is not met
+                if node.step.skip():
+                    utils.log(f"Skipping node {node.step.get_name()}")
+                    node.step.set_state("skipped")
+                    utils.log("NEXTPIPE_DAG_UPDATE=" + self.graph._persist_dag_update(node))

-                # Wait until at least one task is done
-                task_done = False
-                while not task_done:
-                    time.sleep(0.1)
-                    # Check if any tasks are done, if not, keep waiting
-                    for node, task in list(tasks.items()):
-                        if task.ready():
-                            # Remove task and mark successors as ready by adding them to the open list.
-                            result = task.get()
-                            self.set_result(node, result)
-                            node.step.set_state("succeeded")
-                            utils.log("NEXTPIPE_DAG_UPDATE=" + self.graph._persist_dag_update(node))
-                            del tasks[node]
-                            task_done = True
-                            closed_nodes.add(node)
-                            open_nodes.update(node.successors)
+                    continue
+                # Run the node asynchronously
+                job = threads.Job(self.__run_node, (node, self._get_inputs(node), self.client))
+                pool.run(job)
+                tasks[node] = job
+                node.step.set_state("running")
+                utils.log("NEXTPIPE_DAG_UPDATE=" + self.graph._persist_dag_update(node))
+
+            # Wait until at least one task is done
+            task_done = False
+            while not task_done:
+                time.sleep(0.1)
+                # Check if any tasks are done, if not, keep waiting
+                for node, job in list(tasks.items()):
+                    if job.done:
+                        # Remove task and mark successors as ready by adding them to the open list.
+                        self.set_result(node, job.result)
+                        node.step.set_state("succeeded")
+                        utils.log("NEXTPIPE_DAG_UPDATE=" + self.graph._persist_dag_update(node))
+                        del tasks[node]
+                        task_done = True
+                        closed_nodes.add(node)
+                        open_nodes.update(node.successors)

     def set_result(self, step: callable, result: object):
         self.results[step.step] = result

@@ -206,7 +201,7 @@ def __init__(self, flow_spec):
         # Create a Mermaid diagram of the graph and log it
         mermaid = self._to_mermaid()
         utils.log(mermaid)
-        mermaid_url = f'https://mermaid.ink/svg/{base64.b64encode(mermaid.encode("utf8")).decode("ascii")}?theme=dark'
+        mermaid_url = f"https://mermaid.ink/svg/{base64.b64encode(mermaid.encode('utf8')).decode('ascii')}?theme=dark"
         utils.log(f"Mermaid URL: {mermaid_url}")

     def __create_graph(self, flow_spec):
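
For orientation (an editorial sketch, not part of the diff): the reworked loop submits a threads.Job for every ready node and then polls job.done on a 100 ms interval instead of blocking on a pool future. Below is a minimal standalone sketch of that same submit-and-poll pattern; run_step and the step names are hypothetical stand-ins for __run_node and the DAG bookkeeping.

import time

from nextpipe import threads


def run_step(name: str) -> str:
    # Hypothetical stand-in for Flow.__run_node: do some work, return a result.
    time.sleep(0.2)
    return f"result of {name}"


pool = threads.Pool(8)
jobs = {}
for name in ("step-a", "step-b", "step-c"):
    job = threads.Job(run_step, (name,))
    pool.run(job)  # Blocks only while all 8 worker slots are busy.
    jobs[name] = job

# Poll for completion the same way the reworked flow.py loop does.
while jobs:
    time.sleep(0.1)
    for name, job in list(jobs.items()):
        if job.done:
            print(name, "->", job.result)
            del jobs[name]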
75 changes: 75 additions & 0 deletions nextpipe/threads.py

@@ -0,0 +1,75 @@
+import threading
+import time
+from typing import Callable, Optional
+
+
+class Job:
+    def __init__(self, target: Callable, args: Optional[tuple] = None):
+        self.target = target
+        self.args = args
+        self.done = False
+        self.result = None
+
+    def run(self):
+        if self.args:
+            self.result = self.target(*self.args)
+        else:
+            self.result = self.target()
+        self.done = True
+
+
+class Pool:
+    def __init__(self, max_threads: int):
+        self.max_threads = max_threads
+        self.counter = 0  # Used to assign unique IDs to threads
+        self.waiting = {}
+        self.running = {}
+        self.lock = threading.Lock()
+        self.cond = threading.Condition(self.lock)
+
+    def run(self, job: Job) -> None:
+        with self.lock:
+            self.counter += 1
+            thread_id = self.counter
+            self.waiting[thread_id] = job
+
+        def worker(job: Job, thread_id: int):
+            try:
+                job.run()
+            finally:
+                with self.lock:
+                    self.running.pop(thread_id, None)
+                    self.cond.notify_all()  # Notify others that a thread is available
+
+        while True:
+            with self.lock:
+                if len(self.running) < self.max_threads:
+                    # Move job from waiting to running
+                    thread = threading.Thread(target=worker, args=(job, thread_id))
+                    self.running[thread_id] = thread
+                    self.waiting.pop(thread_id, None)
+                    thread.start()
+                    break  # Successfully assigned the job to a thread
+                else:
+                    self.cond.wait()  # Wait until a thread is available
+
+    def join(self) -> None:
+        with self.cond:
+            while self.waiting or self.running:
+                self.cond.wait()  # Wait until all jobs are finished
+
+
+def test_pool():
+    def target(*args):
+        print(f"Running job with args: {args}")
+        time.sleep(0.5)  # Simulate work
+
+    pool = Pool(2)
+    for i in range(1, 7):  # Submit 6 jobs
+        pool.run(Job(target, (i,)))
+    pool.join()
+    print("All jobs completed.")
+
+
+if __name__ == "__main__":
+    test_pool()
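
One behavior worth flagging (an editorial observation, not part of the diff): Job.run sets done only after target returns, and the worker wraps job.run() in try/finally, so a raising target releases its pool slot but leaves done as False forever. A poller like the one in flow.py would then spin indefinitely on that job. A defensive variant, sketched here as a suggestion and assuming it lives alongside Job in nextpipe/threads.py:

class SafeJob(Job):
    """Sketch of a Job that records failures instead of hanging pollers."""

    def __init__(self, target: Callable, args: Optional[tuple] = None):
        super().__init__(target, args)
        self.error: Optional[Exception] = None

    def run(self):
        try:
            if self.args:
                self.result = self.target(*self.args)
            else:
                self.result = self.target()
        except Exception as e:
            # Keep the exception for the caller instead of losing it in the worker thread.
            self.error = e
        finally:
            # Always signal completion so pollers checking `done` never hang.
            self.done = True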