POC: Early Exit 2 - sdf.collect(...).to_pandas() #729

Closed · wants to merge 7 commits
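A rough sketch of how the proposed API is meant to be used (the broker address and topic name below are placeholders, not taken from this PR):

# Hypothetical usage of the API this POC proposes; broker address and topic
# name are placeholders.
from quixstreams import Application

app = Application(broker_address="localhost:9092")
sdf = app.dataframe(app.topic("input-topic"))

# Run the app until 100 messages have been collected or 30 seconds have
# elapsed (whichever comes first), then materialize the collected values
# as a pandas DataFrame.
df = sdf.collect(number=100, timeout=30).to_pandas()
print(df.head())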
73 changes: 66 additions & 7 deletions quixstreams/app.py
@@ -75,6 +75,52 @@ def __call__(
) -> TopicManager: ...


class ExitManager:
def __init__(self):
self.should_exit = lambda: False

def configure(self, number: Optional[int] = None, timeout: Optional[int] = None):
"""
Configure the exit manager with a message count limit, a timeout, or both.

`should_exit` is called once for every processed message, so it has to stay
cheap. To keep that hot path minimal, the decision about which checks to run
is made here in `configure`, which binds `should_exit` to the matching check
method, rather than branching inside `should_exit` itself.

Args:
number: Maximum number of messages to process before exiting
timeout: Maximum time in seconds to run before exiting
"""
if number is not None and timeout is not None:
self._number = number
self._timeout = timeout
self._counter = 0
self._start = time.monotonic()
self.should_exit = self._check_number_and_timeout
elif number is not None:
self._number = number
self._counter = 0
self.should_exit = self._check_number
elif timeout is not None:
self._timeout = timeout
self._start = time.monotonic()
self.should_exit = self._check_timeout
else:
self.should_exit = lambda: False

def _check_number(self):
if self._counter >= self._number:
return True
self._counter += 1
return False

def _check_timeout(self):
return (time.monotonic() - self._start) > self._timeout

def _check_number_and_timeout(self):
return self._check_number() or self._check_timeout()
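As a quick illustration of the semantics described in the docstring above (this snippet is not part of the patch): with only a message limit configured, should_exit() lets exactly `number` polls through before flipping to True.

# Illustration only: ExitManager with a message-count limit of 3.
em = ExitManager()
em.configure(number=3)
# should_exit() is polled once per processed message.
print([em.should_exit() for _ in range(5)])  # [False, False, False, True, True]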


class Application:
"""
The main Application class.
@@ -303,17 +349,22 @@ def __init__(

self._on_message_processed = on_message_processed
self._on_processing_error = on_processing_error or default_on_processing_error
self._on_consumer_error = on_consumer_error
self._on_producer_error = on_producer_error
self._topic_manager = topic_manager or self._get_topic_manager()
self._dataframe_registry = DataframeRegistry()
self.exit_manager = ExitManager()
self.reset()

def reset(self):
self._consumer = self._get_rowconsumer(
on_error=on_consumer_error,
on_error=self._on_consumer_error,
extra_config_overrides=consumer_extra_config_overrides,
)
self._producer = self._get_rowproducer(on_error=on_producer_error)
self._producer = self._get_rowproducer(on_error=self._on_producer_error)
self._running = False
self._failed = False

self._topic_manager = topic_manager or self._get_topic_manager()

producer = None
recovery_manager = None
if self._config.use_changelog_topics:
@@ -344,7 +395,6 @@ def __init__(
sink_manager=self._sink_manager,
pausing_manager=self._pausing_manager,
)
self._dataframe_registry = DataframeRegistry()

@property
def config(self) -> "ApplicationConfig":
@@ -518,6 +568,7 @@ def dataframe(
topic_manager=self._topic_manager,
processing_context=self._processing_context,
registry=self._dataframe_registry,
app=self,
)
self._dataframe_registry.register_root(sdf)

@@ -546,6 +597,14 @@ def stop(self, fail: bool = False):
if self._state_manager.using_changelogs:
self._state_manager.stop_recovery()

def running(self):
if not self._running:
return False
elif self.exit_manager.should_exit():
self.stop()
return False
return True

def _get_rowproducer(
self,
on_error: Optional[ProducerErrorCallback] = None,
@@ -803,7 +862,7 @@ def _run_dataframe(self):

dataframes_composed = self._dataframe_registry.compose_all()

while self._running:
while self.running():
if self._state_manager.recovery_required:
self._state_manager.do_recovery()
else:
@@ -818,7 +877,7 @@ def _run_sources(self):
def _run_sources(self):
self._running = True
self._source_manager.start_sources()
while self._running:
while self.running():
self._source_manager.raise_for_error()

if not self._source_manager.is_alive():
55 changes: 55 additions & 0 deletions quixstreams/dataframe/dataframe.py
@@ -6,6 +6,7 @@
import pprint
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -59,6 +60,9 @@
TumblingWindowDefinition,
)

if TYPE_CHECKING:
from quixstreams.app import Application

ApplyCallbackStateful = Callable[[Any, State], Any]
ApplyWithMetadataCallbackStateful = Callable[[Any, Any, int, Any, State], Any]
UpdateCallbackStateful = Callable[[Any, State], None]
@@ -120,6 +124,7 @@ def __init__(
topic_manager: TopicManager,
registry: DataframeRegistry,
processing_context: ProcessingContext,
app: Application,
stream: Optional[Stream] = None,
):
self._stream: Stream = stream or Stream()
@@ -129,6 +134,9 @@
self._processing_context = processing_context
self._producer = processing_context.producer
self._locked = False
self._app = app
self._collection: list[Any] = []
self._collection_active = False

@property
def processing_context(self) -> ProcessingContext:
@@ -767,6 +775,52 @@ def print(self, pretty: bool = True, metadata: bool = False) -> Self:
metadata=metadata,
)

def collect(
self,
number: Optional[int] = None,
timeout: Optional[int] = None,
):
if number is None and timeout is None:
raise ValueError(
"Either number or timeout must be provided. "
"Otherwise Application will run forever."
)
self._app.exit_manager.configure(number, timeout)
self._collection = []

if not self._collection_active:

def _collect(value):
nonlocal self
self._collection.append(value)

self._collection_active = True
self = self._add_update(_collect, metadata=False)

self._app.run()
self._app.reset()
return self

def to_pandas(self):
try:
import pandas as pd
except ImportError:
raise ImportError(
"Pandas is not installed. "
"Run `pip install quixstreams[pandas]` to install it."
)
return pd.DataFrame(self._collection)

def to_polars(self):
try:
import polars as pl
except ImportError:
raise ImportError(
"Polars is not installed. "
"Run `pip install quixstreams[polars]` to install it."
)
return pl.DataFrame(self._collection)
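For completeness, a sketch of the polars path under the same placeholder names as the example near the top of this PR; per the error messages above, the optional dependency is assumed to come from the quixstreams[polars] extra:

# Hypothetical usage of the polars variant; names are placeholders.
from quixstreams import Application

app = Application(broker_address="localhost:9092")
sdf = app.dataframe(app.topic("input-topic"))

# collect() refuses to run unbounded:
# sdf.collect()  # -> ValueError: Either number or timeout must be provided ...

# Bound the run by time only, then materialize with polars instead of pandas.
df = sdf.collect(timeout=10).to_polars()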

def compose(
self,
sink: Optional[VoidExecutor] = None,
@@ -1230,6 +1284,7 @@ def __dataframe_clone__(
processing_context=self._processing_context,
topic_manager=self._topic_manager,
registry=self._registry,
app=self._app,
)
return clone

2 changes: 2 additions & 0 deletions tests/test_quixstreams/test_dataframe/fixtures.py
@@ -3,6 +3,7 @@

import pytest

from quixstreams.app import Application
from quixstreams.dataframe.dataframe import StreamingDataFrame
from quixstreams.dataframe.registry import DataframeRegistry
from quixstreams.models.topics import Topic, TopicManager
@@ -48,6 +49,7 @@ def factory(
topic_manager=topic_manager,
registry=registry,
processing_context=processing_ctx,
app=MagicMock(spec_set=Application),
)
registry.register_root(sdf)
return sdf
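A direct test of the new ExitManager could sit next to these fixtures; a minimal sketch, assuming ExitManager remains importable from quixstreams.app as in this diff (not part of the PR):

import time

from quixstreams.app import ExitManager


def test_exit_manager_number_limit():
    em = ExitManager()
    em.configure(number=2)
    # The first `number` polls let messages through; the next one signals exit.
    assert em.should_exit() is False
    assert em.should_exit() is False
    assert em.should_exit() is True


def test_exit_manager_timeout():
    em = ExitManager()
    em.configure(timeout=0)  # expire immediately for the test
    time.sleep(0.01)
    assert em.should_exit() is True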