Skip to content

Commit

Permalink
Loopback cmds and resume implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomperez98 committed Dec 20, 2024
1 parent a649c5f commit d1424af
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/resonate/cmd_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Invoke:
@dataclass(frozen=True)
class Resume:
id: str
result: Result[Any, Exception]


@dataclass(frozen=True)
Expand Down
60 changes: 33 additions & 27 deletions src/resonate/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,25 +168,22 @@ def _heartbeat(self) -> None:
time.sleep(2)

def _loop(self) -> None:
# wait until an event occurs, either:
# - resonate run is called
# - a completion is enqueued by the processor
# - a task is enqueued by the task source
while True:
# immediately clear the event so the next tick waits
# unless another event occurs in the meantime
cmd = self._cmd_queue.dequeue()
if cmd is None:
break

# start the next tick
self._tick(cmd)
for loopback in self._tick(cmd):
self._cmd_queue.enqueue(loopback)
assert not self._runnable

def _tick(self, cmd: Command) -> None: # noqa: C901,PLR0912
def _tick(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912
loopback_cmds: list[Command] = []
if isinstance(cmd, Invoke):
self._ingest(cmd.id)
elif isinstance(cmd, Complete):
self._process_final_value(cmd.id, cmd.result)
self._process_final_value(cmd.id, cmd.result, loopback_cmds)
elif isinstance(cmd, Claim):
assert isinstance(self._store, RemoteStore)
invoke_or_resume = self._store.tasks.claim(
Expand All @@ -196,13 +193,13 @@ def _tick(self, cmd: Command) -> None: # noqa: C901,PLR0912
ttl=5 * 1000,
)
if isinstance(invoke_or_resume, InvokeMsg):
self._process_invoke_msg(invoke_or_resume, cmd.record)
self._process_invoke_msg(invoke_or_resume, cmd.record, loopback_cmds)
elif isinstance(invoke_or_resume, ResumeMsg):
self._process_resume_msg(invoke_or_resume, cmd.record)
self._process_resume_msg(invoke_or_resume, cmd.record, loopback_cmds)
else:
assert_never(invoke_or_resume)
elif isinstance(cmd, Resume):
raise NotImplementedError
self._add_to_runnable(cmd.id, cmd.result)
else:
assert_never(cmd)

Expand All @@ -225,14 +222,18 @@ def _tick(self, cmd: Command) -> None: # noqa: C901,PLR0912
elif isinstance(yielded_value, Promise):
self._process_promise(record, yielded_value)
elif isinstance(yielded_value, FinalValue):
self._process_final_value(record, yielded_value.v)
self._process_final_value(record, yielded_value.v, loopback_cmds)
elif isinstance(yielded_value, DI):
# start execution from the top. Add current record to runnable
raise NotImplementedError
else:
assert_never(yielded_value)

def _process_invoke_msg(self, invoke_msg: InvokeMsg, task: TaskRecord) -> None:
return loopback_cmds

def _process_invoke_msg(
self, invoke_msg: InvokeMsg, task: TaskRecord, loopback_cmds: list[Command]
) -> None:
logger.info(
"Invoke message for %s received", invoke_msg.root_durable_promise.id
)
Expand Down Expand Up @@ -264,9 +265,11 @@ def _process_invoke_msg(self, invoke_msg: InvokeMsg, task: TaskRecord) -> None:
record.add_durable_promise(invoke_msg.root_durable_promise)

record.add_task(task=task)
self._cmd_queue.enqueue(Invoke(record.id))
loopback_cmds.append(Invoke(record.id))

def _process_resume_msg(self, resume_msg: ResumeMsg, task: TaskRecord) -> None:
def _process_resume_msg(
self, resume_msg: ResumeMsg, task: TaskRecord, loopback_cmds: list[Command]
) -> None:
logger.info(
"Resume message for %s received", resume_msg.leaf_durable_promise.id
)
Expand All @@ -285,10 +288,12 @@ def _process_resume_msg(self, resume_msg: ResumeMsg, task: TaskRecord) -> None:
deduping=True,
)

self._unblock_awaiting_remote(leaf_record.id)
self._unblock_awaiting_remote(leaf_record.id, loopback_cmds)

else:
self._process_invoke_msg(InvokeMsg(resume_msg.root_durable_promise), task)
self._process_invoke_msg(
InvokeMsg(resume_msg.root_durable_promise), task, loopback_cmds
)

def _process_promise(self, record: Record[Any], promise: Promise[Any]) -> None:
promise_record = self._records[promise.id]
Expand Down Expand Up @@ -329,18 +334,17 @@ def _process_promise(self, record: Record[Any], promise: Promise[Any]) -> None:
else:
assert_never(promise_record.invocation)

def _unblock_awaiting_local(self, id: str) -> None:
def _unblock_awaiting_local(self, id: str, loopback_cmds: list[Command]) -> None:
record = self._records[id]
assert record.done()
for blocked_id in self._awaiting_lfi.pop(record.id, []):
blocked_record = self._records[blocked_id]
assert blocked_record.blocked_on
assert blocked_record.blocked_on.id == id
blocked_record.blocked_on = None
logger.info("Unblocking %s. Who has blocked locally on %s", blocked_id, id)
self._add_to_runnable(blocked_id, next_value=record.safe_result())
loopback_cmds.append(Resume(blocked_id, record.safe_result()))

def _unblock_awaiting_remote(self, id: str) -> None:
def _unblock_awaiting_remote(self, id: str, loopback_cmds: list[Command]) -> None:
record = self._records[id]
assert record.done()
assert isinstance(record.invocation, RFI)
Expand All @@ -349,8 +353,7 @@ def _unblock_awaiting_remote(self, id: str) -> None:
assert blocked_record.blocked_on
assert blocked_record.blocked_on.id == id
blocked_record.blocked_on = None
logger.info("Unblocking %s. Who has blocked remotely on %s", blocked_id, id)
self._add_to_runnable(blocked_id, next_value=record.safe_result())
loopback_cmds.append(Resume(blocked_id, record.safe_result()))

def _add_to_runnable(
self, id: str, next_value: Result[Any, Exception] | None
Expand Down Expand Up @@ -637,7 +640,10 @@ def _process_lfi(self, record: Record[Any], lfi: LFI) -> None:
self._add_to_runnable(record.id, Ok(child_record.promise))

def _process_final_value(
self, record: str | Record[Any], final_value: Result[Any, Exception]
self,
record: str | Record[Any],
final_value: Result[Any, Exception],
loopback_cmds: list[Command],
) -> None:
if isinstance(record, str):
record = self._records[record]
Expand Down Expand Up @@ -674,15 +680,15 @@ def _process_final_value(
self._complete_task(record.id)

record.set_result(final_value, deduping=False)
self._unblock_awaiting_local(record.id)
self._unblock_awaiting_local(record.id, loopback_cmds)

root = record.root()
if root != record and self._blocked_only_on_remote(root.id):
self._complete_task(root.id)

else:
record.set_result(final_value, deduping=False)
self._unblock_awaiting_local(record.id)
self._unblock_awaiting_local(record.id, loopback_cmds)

def _get_info_from_rfi(self, rfi: RFI) -> tuple[Data, Headers, Tags, int | None]:
data: Data
Expand Down

0 comments on commit d1424af

Please sign in to comment.