Skip to content

Commit

Permalink
Step control flow without passing loopback cmds mutable reference
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomperez98 committed Dec 21, 2024
1 parent 584e0ef commit 549e32f
Showing 1 changed file with 45 additions and 30 deletions.
75 changes: 45 additions & 30 deletions src/resonate/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,23 @@ def _loop(self) -> None:
break

# start the next tick
for loopback in self._tick(cmd):
for loopback in self._step(cmd):
self._cmd_queue.put(loopback)
assert not self._runnable

# mark task done
self._cmd_queue.task_done()

def _tick(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912
def _step(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912
loopback_cmds: list[Command] = []
if isinstance(cmd, Invoke):
self._ingest(cmd.id)
elif isinstance(cmd, Resume):
self._add_to_runnable(cmd.id, cmd.result)
elif isinstance(cmd, Complete):
self._process_final_value(cmd.id, cmd.result, loopback_cmds)
resume_cmds = self._process_final_value(cmd.id, cmd.result)
if resume_cmds:
loopback_cmds.extend(resume_cmds)
elif isinstance(cmd, Claim):
assert isinstance(self._store, RemoteStore)
invoke_or_resume = self._store.tasks.claim(
Expand All @@ -197,13 +201,20 @@ def _tick(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912
ttl=5 * 1000,
)
if isinstance(invoke_or_resume, InvokeMsg):
self._process_invoke_msg(invoke_or_resume, cmd.record, loopback_cmds)
invoke_cmd = self._process_invoke_msg(invoke_or_resume, cmd.record)
loopback_cmds.append(invoke_cmd)
elif isinstance(invoke_or_resume, ResumeMsg):
self._process_resume_msg(invoke_or_resume, cmd.record, loopback_cmds)
invoke_or_resume_cmds = self._process_resume_msg(
invoke_or_resume, cmd.record
)
if isinstance(invoke_or_resume_cmds, Invoke):
loopback_cmds.append(invoke_or_resume_cmds)
elif isinstance(invoke_or_resume_cmds, list):
loopback_cmds.extend(invoke_or_resume_cmds)
else:
assert_never(invoke_or_resume_cmds)
else:
assert_never(invoke_or_resume)
elif isinstance(cmd, Resume):
self._add_to_runnable(cmd.id, cmd.result)
else:
assert_never(cmd)

Expand All @@ -226,7 +237,10 @@ def _tick(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912
elif isinstance(yielded_value, Promise):
self._process_promise(record, yielded_value)
elif isinstance(yielded_value, FinalValue):
self._process_final_value(record, yielded_value.v, loopback_cmds)
resume_cmds = self._process_final_value(record, yielded_value.v)
if resume_cmds:
loopback_cmds.extend(resume_cmds)

elif isinstance(yielded_value, DI):
# start execution from the top. Add current record to runnable
raise NotImplementedError
Expand All @@ -235,9 +249,7 @@ def _tick(self, cmd: Command) -> list[Command]: # noqa: C901,PLR0912

return loopback_cmds

def _process_invoke_msg(
self, invoke_msg: InvokeMsg, task: TaskRecord, loopback_cmds: list[Command]
) -> None:
def _process_invoke_msg(self, invoke_msg: InvokeMsg, task: TaskRecord) -> Invoke:
logger.info(
"Invoke message for %s received", invoke_msg.root_durable_promise.id
)
Expand Down Expand Up @@ -269,11 +281,11 @@ def _process_invoke_msg(
record.add_durable_promise(invoke_msg.root_durable_promise)

record.add_task(task=task)
loopback_cmds.append(Invoke(record.id))
return Invoke(record.id)

def _process_resume_msg(
self, resume_msg: ResumeMsg, task: TaskRecord, loopback_cmds: list[Command]
) -> None:
self, resume_msg: ResumeMsg, task: TaskRecord
) -> Invoke | list[Resume]:
logger.info(
"Resume message for %s received", resume_msg.leaf_durable_promise.id
)
Expand All @@ -292,12 +304,11 @@ def _process_resume_msg(
deduping=True,
)

self._unblock_awaiting_remote(leaf_record.id, loopback_cmds)
return self._unblock_awaiting_remote(leaf_record.id)

else:
self._process_invoke_msg(
InvokeMsg(resume_msg.root_durable_promise), task, loopback_cmds
)
return self._process_invoke_msg(
InvokeMsg(resume_msg.root_durable_promise), task
)

def _process_promise(self, record: Record[Any], promise: Promise[Any]) -> None:
promise_record = self._records[promise.id]
Expand Down Expand Up @@ -338,26 +349,30 @@ def _process_promise(self, record: Record[Any], promise: Promise[Any]) -> None:
else:
assert_never(promise_record.invocation)

def _unblock_awaiting_local(self, id: str, loopback_cmds: list[Command]) -> None:
def _unblock_awaiting_local(self, id: str) -> list[Resume]:
record = self._records[id]
assert record.done()
resume_cmds: list[Resume] = []
for blocked_id in self._awaiting_lfi.pop(record.id, []):
blocked_record = self._records[blocked_id]
assert blocked_record.blocked_on
assert blocked_record.blocked_on.id == id
blocked_record.blocked_on = None
loopback_cmds.append(Resume(blocked_id, record.safe_result()))
resume_cmds.append(Resume(blocked_id, record.safe_result()))
return resume_cmds

def _unblock_awaiting_remote(self, id: str, loopback_cmds: list[Command]) -> None:
def _unblock_awaiting_remote(self, id: str) -> list[Resume]:
record = self._records[id]
assert record.done()
assert isinstance(record.invocation, RFI)
resume_cmds: list[Resume] = []
for blocked_id in self._awaiting_rfi.pop(record.id, []):
blocked_record = self._records[blocked_id]
assert blocked_record.blocked_on
assert blocked_record.blocked_on.id == id
blocked_record.blocked_on = None
loopback_cmds.append(Resume(blocked_id, record.safe_result()))
resume_cmds.append(Resume(blocked_id, record.safe_result()))
return resume_cmds

def _add_to_runnable(
self, id: str, next_value: Result[Any, Exception] | None
Expand Down Expand Up @@ -647,15 +662,15 @@ def _process_final_value(
self,
record: str | Record[Any],
final_value: Result[Any, Exception],
loopback_cmds: list[Command],
) -> None:
) -> list[Resume] | None:
if isinstance(record, str):
record = self._records[record]
if record.should_retry(final_value):
record.increate_attempt()
self._delay_queue.enqueue(Invoke(record.id), record.next_retry_delay())
return None

elif record.invocation.opts.durable:
if record.invocation.opts.durable:
durable_promise: DurablePromiseRecord
if isinstance(final_value, Ok):
durable_promise = self._store.promises.resolve(
Expand Down Expand Up @@ -684,15 +699,15 @@ def _process_final_value(
self._complete_task(record.id)

record.set_result(final_value, deduping=False)
self._unblock_awaiting_local(record.id, loopback_cmds)
resume_cmds = self._unblock_awaiting_local(record.id)

root = record.root()
if root != record and self._blocked_only_on_remote(root.id):
self._complete_task(root.id)
return resume_cmds

else:
record.set_result(final_value, deduping=False)
self._unblock_awaiting_local(record.id, loopback_cmds)
record.set_result(final_value, deduping=False)
return self._unblock_awaiting_local(record.id)

def _get_info_from_rfi(self, rfi: RFI) -> tuple[Data, Headers, Tags, int | None]:
data: Data
Expand Down

0 comments on commit 549e32f

Please sign in to comment.