Skip to content

Commit

Permalink
fix: make ctx._events_buffer json-serializable (#17676)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Jan 30, 2025
1 parent d22fdcd commit 7e396ae
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
17 changes: 10 additions & 7 deletions llama-index-core/llama_index/core/workflow/context.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio
import json
import warnings
import uuid
import warnings
from collections import defaultdict
from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar

from .context_serializers import BaseSerializer, JsonSerializer
from .decorators import StepConfig
from .events import Event
from .errors import WorkflowRuntimeError
from .events import Event

if TYPE_CHECKING: # pragma: no cover
from .workflow import Workflow
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
self._lock = asyncio.Lock()
self._globals: Dict[str, Any] = {}
# Step-specific instance
self._events_buffer: Dict[Type[Event], List[Event]] = defaultdict(list)
self._events_buffer: Dict[str, List[Event]] = defaultdict(list)

def _serialize_queue(self, queue: asyncio.Queue, serializer: BaseSerializer) -> str:
queue_items = list(queue._queue) # type: ignore
Expand Down Expand Up @@ -240,14 +240,17 @@ def session(self) -> "Context":
warnings.warn(msg, DeprecationWarning)
return self

def _get_full_path(self, ev_type: Type[Event]) -> str:
return f"{ev_type.__module__}.{ev_type.__name__}"

def collect_events(
self, ev: Event, expected: List[Type[Event]]
) -> Optional[List[Event]]:
self._events_buffer[type(ev)].append(ev)
self._events_buffer[self._get_full_path(type(ev))].append(ev)

retval: List[Event] = []
for e_type in expected:
e_instance_list = self._events_buffer.get(e_type)
e_instance_list = self._events_buffer.get(self._get_full_path(e_type))
if e_instance_list:
retval.append(e_instance_list.pop(0))

Expand All @@ -256,7 +259,7 @@ def collect_events(

# put back the events if unable to collect all
for ev in retval:
self._events_buffer[type(ev)].append(ev)
self._events_buffer[self._get_full_path(type(ev))].append(ev)

return None

Expand Down
19 changes: 12 additions & 7 deletions llama-index-core/tests/workflow/test_context.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import asyncio
import json
from typing import Optional, Union
from unittest import mock
from typing import Union, Optional

import pytest
from llama_index.core.workflow.decorators import step
from llama_index.core.workflow.errors import WorkflowRuntimeError
from llama_index.core.workflow.events import Event, StartEvent, StopEvent
from llama_index.core.workflow.workflow import (
Workflow,
Context,
Workflow,
)
from llama_index.core.workflow.decorators import step
from llama_index.core.workflow.errors import WorkflowRuntimeError
from llama_index.core.workflow.events import StartEvent, StopEvent, Event
from llama_index.core.workflow.workflow import Workflow

from .conftest import OneTestEvent, AnotherTestEvent
from .conftest import AnotherTestEvent, OneTestEvent


@pytest.mark.asyncio()
Expand Down Expand Up @@ -115,6 +115,11 @@ def test_get_result(ctx):
assert ctx.get_result() == 42


def test_to_dict_with_events_buffer(ctx):
ctx.collect_events(OneTestEvent(), [OneTestEvent, AnotherTestEvent])
assert json.dumps(ctx.to_dict())


@pytest.mark.asyncio()
async def test_deprecated_params(ctx):
with pytest.warns(
Expand Down
17 changes: 15 additions & 2 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ async def original_step(
ctx.session.send_event(OneTestEvent(test_param="test2"))
ctx.session.send_event(OneTestEvent(test_param="test3"))

# send one extra event
ctx.session.send_event(AnotherTestEvent(another_test_param="test4"))

return LastEvent()

@step(num_workers=3)
Expand All @@ -241,11 +244,21 @@ async def final_step(
workflow = NumWorkersWorkflow()

start_time = time.time()
result = await workflow.run()
handler = workflow.run()
result = await handler
end_time = time.time()

assert workflow.is_done()
assert set(result) == {"test1", "test2", "test3"}
assert set(result) == {"test1", "test2", "test4"}

# ctx should have 1 extra event
assert (
len(handler.ctx._events_buffer["tests.workflow.conftest.AnotherTestEvent"]) == 1
)

# ensure ctx is serializable
ctx = handler.ctx
ctx.to_dict()

# Check if the execution time is close to 1 second (with some tolerance)
execution_time = end_time - start_time
Expand Down

0 comments on commit 7e396ae

Please sign in to comment.