estuary-cdk: remove global Logger and adopt ordering convention
Establish a convention that `log: Logger` is the first parameter.
We're going to be threading these through everywhere -- which is
desirable, because it gives us a tightly-scoped structured log context
that tells us as much as possible about the surrounding task -- so let's
standardize how it should be passed so we don't have to think hard about
it.

Also refactor the `http` module to clarify which APIs are stable, versus
portions that are very likely to be refactored.

A few other code-review cleanups as well.
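
To make the convention concrete, here is a minimal runnable sketch of
log-first parameter ordering and the scoped-logger rationale (the names
are illustrative, not CDK APIs):

```python
import asyncio
from logging import DEBUG, Logger, basicConfig, getLogger


# Convention: `log` is the first parameter, ahead of the payload.
async def fetch_page(log: Logger, page: int) -> None:
    log.debug("fetching page %d", page)


async def main() -> None:
    basicConfig(level=DEBUG)
    root = getLogger("connector")
    # A child logger carries tightly-scoped context naming the task.
    await fetch_page(root.getChild("backfill"), 1)


asyncio.run(main())
```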
jgraettinger committed Mar 4, 2024
1 parent cf23725 commit e8e7b6b
Showing 13 changed files with 277 additions and 233 deletions.
10 changes: 1 addition & 9 deletions .github/actions/setup/action.yaml
@@ -55,12 +55,4 @@ runs:
       credentials_json: ${{ inputs.gcp_service_account_key }}
 
   - name: Set up GCloud SDK
-    uses: google-github-actions/setup-gcloud@v2
-
-  # - name: Set up Cloud SDK
-  #   if: ${{ inputs.gcp_project_id }}
-  #   uses: google-github-actions/setup-gcloud@v0
-  #   with:
-  #     project_id: ${{ inputs.gcp_project_id }}
-  #     service_account_key: ${{ inputs.gcp_service_account_key }}
-  #     export_default_credentials: true
+    uses: google-github-actions/setup-gcloud@v2
34 changes: 20 additions & 14 deletions estuary-cdk/estuary_cdk/__init__.py
@@ -11,8 +11,6 @@
 from .logger import init_logger
 from .flow import ValidationError
 
-logger = init_logger()
-
 # Request type served by this connector.
 Request = TypeVar("Request", bound=BaseModel)
 
@@ -34,8 +32,8 @@ async def stdin_jsonl(cls: type[Request]) -> AsyncGenerator[Request, None]:
 # they're entered prior to handling any requests, and exited after all requests have
 # been fully processed.
 class Mixin(abc.ABC):
-    async def _mixin_enter(self, logger: Logger): ...
-    async def _mixin_exit(self, logger: Logger): ...
+    async def _mixin_enter(self, log: Logger): ...
+    async def _mixin_exit(self, log: Logger): ...
 
 
 @dataclass
@@ -46,6 +44,7 @@ class Stopped(Exception):
     `error`, if set, is a fatal error condition that caused the connector to exit.
     """
+
     error: str | None
 
 
@@ -64,7 +63,7 @@ def request_class(cls) -> type[Request]:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def handle(self, request: Request, logger: Logger) -> None:
+    async def handle(self, log: Logger, request: Request) -> None:
         raise NotImplementedError()
 
     # Serve this connector by invoking `handle()` for all incoming instances of
@@ -75,11 +74,15 @@ async def handle(self, request: Request, logger: Logger) -> None:
     # exit code which indicates whether an error occurred.
     async def serve(
         self,
+        log: Logger | None = None,
         requests: Callable[
             [type[Request]], AsyncGenerator[Request, None]
         ] = stdin_jsonl,
-        logger: Logger = logger,
     ):
+        if not log:
+            log = init_logger()
+
+        assert isinstance(log, Logger)  # Narrow type to non-None.
+
         loop = asyncio.get_running_loop()
         this_task = asyncio.current_task(loop)
@@ -100,7 +103,7 @@ def dump_all_tasks(signum, frame):
             msg, args = type(exc).__name__, exc.args
             if len(args) != 0:
                 msg = f"{msg}: {args[0]}"
-            logger.exception(msg, args)
+            log.exception(msg, args)
 
         # We manually injected an exception into the coroutine,
         # so the asyncio event loop will attempt to await it again
@@ -117,35 +120,38 @@ def dump_all_tasks(signum, frame):
         # Call _mixin_enter() on all mixed-in base classes.
         for base in self.__class__.__bases__:
             if enter := getattr(base, "_mixin_enter", None):
-                await enter(self, logger)
+                await enter(self, log)
 
         failed = False
         try:
             async with asyncio.TaskGroup() as group:
                 async for request in requests(self.request_class()):
-                    group.create_task(self.handle(request, logger))
+                    group.create_task(self.handle(log, request))
 
         except ExceptionGroup as exc_group:
             for exc in exc_group.exceptions:
                 if isinstance(exc, ValidationError):
                     if len(exc.errors) == 1:
-                        logger.error(exc.errors[0])
+                        log.error(exc.errors[0])
                     else:
-                        logger.error("Multiple validation errors:\n - " + "\n - ".join(exc.errors))
+                        log.error(
+                            "Multiple validation errors:\n - "
+                            + "\n - ".join(exc.errors)
+                        )
                     failed = True
                 elif isinstance(exc, Stopped):
                     if exc.error:
-                        logger.error(f"{exc.error}")
+                        log.error(f"{exc.error}")
                     failed = True
                 else:
-                    logger.error("".join(traceback.format_exception(exc)))
+                    log.error("".join(traceback.format_exception(exc)))
                     failed = True
 
         finally:
             # Call _mixin_exit() on all mixed-in base classes, in reverse order.
             for base in reversed(self.__class__.__bases__):
                 if exit := getattr(base, "_mixin_exit", None):
-                    await exit(self, logger)
+                    await exit(self, log)
 
         # Restore the original signal handler
         signal.signal(signal.SIGQUIT, original_sigquit)
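
Given the `Mixin` hooks and the new `serve()` default above, a consuming
connector might be wired up like this sketch (`BaseConnector` as the name
of the abstract base, and the `Request` plumbing, are assumptions not
shown in this diff):

```python
import asyncio
from logging import Logger


class HTTPSessionMixin(Mixin):  # Mixin is defined in the module above.
    # Entered before any request is handled, exited after all complete.
    async def _mixin_enter(self, log: Logger) -> None:
        log.debug("opening shared HTTP session")

    async def _mixin_exit(self, log: Logger) -> None:
        log.debug("closing shared HTTP session")


class MyConnector(HTTPSessionMixin, BaseConnector):  # hypothetical base name
    async def handle(self, log: Logger, request: Request) -> None:
        log.info("handling request")


# serve() now creates its own Logger via init_logger() when none is given:
# asyncio.run(MyConnector().serve())
```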
40 changes: 20 additions & 20 deletions estuary-cdk/estuary_cdk/capture/__init__.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
 from pydantic import Field
 from typing import Generic, Awaitable, Any, BinaryIO, Callable
+from logging import Logger
 import abc
 import asyncio
-import logging
 import shutil
 import sys
 import tempfile
@@ -57,8 +55,8 @@ class Task:
     Task also facilitates logging and graceful stop of a capture coroutine.
     """
 
-    logger: logging.Logger
-    """Attached Logger of this Task instance, to use for non-protocol logging."""
+    log: Logger
+    """Attached Logger of this Task instance, to use for scoped logging."""
 
     @dataclass
     class Stopping:
@@ -89,7 +87,7 @@ class Stopping:
 
     def __init__(
         self,
-        logger: logging.Logger,
+        log: Logger,
         name: str,
         output: BinaryIO,
         stopping: Stopping,
@@ -100,7 +100,7 @@ def __init__(
         self._name = name
         self._output = output
         self._tg = tg
-        self.logger = logger
+        self.log = log
         self.stopping = stopping
 
     def captured(self, binding: int, document: Any):
@@ -148,21 +148,21 @@ def spawn_child(
         """
 
         child_name = f"{self._name}.{name_suffix}"
-        child_logger = self.logger.getChild(name_suffix)
+        child_log = self.log.getChild(name_suffix)
 
         async def run_task(parent: Task):
             async with asyncio.TaskGroup() as child_tg:
                 try:
                     t = Task(
-                        child_logger,
+                        child_log,
                         child_name,
                         parent._output,
                         parent.stopping,
                         child_tg,
                     )
                     await child(t)
                 except Exception as exc:
-                    child_logger.error("".join(traceback.format_exception(exc)))
+                    child_log.error("".join(traceback.format_exception(exc)))
 
                     if parent.stopping.first_error is None:
                         parent.stopping.first_error = exc
@@ -192,37 +192,37 @@ class BaseCaptureConnector(
     output: BinaryIO = sys.stdout.buffer
 
     @abc.abstractmethod
-    async def spec(self, _: request.Spec, logger: logging.Logger) -> ConnectorSpec:
+    async def spec(self, log: Logger, _: request.Spec) -> ConnectorSpec:
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def discover(
         self,
+        log: Logger,
         discover: request.Discover[EndpointConfig],
-        logger: logging.Logger,
     ) -> response.Discovered[ResourceConfig]:
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def validate(
         self,
+        log: Logger,
         validate: request.Validate[EndpointConfig, ResourceConfig],
-        logger: logging.Logger,
     ) -> response.Validated:
         raise NotImplementedError()
 
     async def apply(
         self,
+        log: Logger,
         apply: request.Apply[EndpointConfig, ResourceConfig],
-        logger: logging.Logger,
     ) -> response.Applied:
         return response.Applied(actionDescription="")
 
     @abc.abstractmethod
     async def open(
         self,
+        log: Logger,
         open: request.Open[EndpointConfig, ResourceConfig, ConnectorState],
-        logger: logging.Logger,
     ) -> tuple[response.Opened, Callable[[Task], Awaitable[None]]]:
         raise NotImplementedError()
 
@@ -231,26 +231,26 @@ async def acknowledge(self, acknowledge: request.Acknowledge) -> None:
 
     async def handle(
         self,
+        log: Logger,
         request: Request[EndpointConfig, ResourceConfig, ConnectorState],
-        logger: logging.Logger,
     ) -> None:
 
         if spec := request.spec:
-            response = await self.spec(spec, logger)
+            response = await self.spec(log, spec)
             response.protocol = 3032023
             self._emit(Response(spec=response))
 
         elif discover := request.discover:
-            self._emit(Response(discovered=await self.discover(discover, logger)))
+            self._emit(Response(discovered=await self.discover(log, discover)))
 
         elif validate := request.validate_:
-            self._emit(Response(validated=await self.validate(validate, logger)))
+            self._emit(Response(validated=await self.validate(log, validate)))
 
         elif apply := request.apply:
-            self._emit(Response(applied=await self.apply(apply, logger)))
+            self._emit(Response(applied=await self.apply(log, apply)))
 
         elif open := request.open:
-            opened, capture = await self.open(open, logger)
+            opened, capture = await self.open(log, open)
             self._emit(Response(opened=opened))
 
             stopping = Task.Stopping(asyncio.Event())
@@ -267,7 +267,7 @@ async def stop_on_elapsed_interval(interval: int) -> None:
             async with asyncio.TaskGroup() as tg:
 
                 task = Task(
-                    logger.getChild("capture"),
+                    log.getChild("capture"),
                     "capture",
                     self.output,
                     stopping,
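
A sketch of how a capture coroutine threads the scoped logger through
child tasks, following the `spawn_child` change above (the binding name
and the exact `spawn_child` argument order are read off this excerpt and
may not match the full signature):

```python
from estuary_cdk.capture import Task


async def run_binding(task: Task) -> None:
    # Each Task carries task.log, a Logger scoped via getChild(), so log
    # context mirrors the task tree, e.g. "capture.binding-0".
    task.log.info("binding is starting")


async def capture(task: Task) -> None:
    task.spawn_child("binding-0", run_binding)
```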
26 changes: 13 additions & 13 deletions estuary-cdk/estuary_cdk/capture/common.py
@@ -176,7 +176,7 @@ class ConnectorState(BaseModel, Generic[_BaseResourceState], extra="forbid"):
 """
 
 FetchPageFn = Callable[
-    [PageCursor, LogCursor, Logger],
+    [Logger, PageCursor, LogCursor],
     Awaitable[tuple[Iterable[_BaseDocument], PageCursor]],
 ]
 """
@@ -192,7 +192,7 @@ class ConnectorState(BaseModel, Generic[_BaseResourceState], extra="forbid"):
 """
 
 FetchChangesFn = Callable[
-    [LogCursor, Logger],
+    [Logger, LogCursor],
     AsyncGenerator[_BaseDocument | LogCursor, None],
 ]
 """
@@ -455,7 +455,7 @@ async def _binding_snapshot_task(
         next_sync = state.updated_at + binding.resourceConfig.interval
         sleep_for = next_sync - datetime.now(tz=UTC)
 
-        task.logger.debug(
+        task.log.debug(
             "awaiting next snapshot",
             {"sleep_for": sleep_for, "next": next_sync},
         )
@@ -466,22 +466,22 @@ async def _binding_snapshot_task(
                 task.stopping.event.wait(), timeout=sleep_for.total_seconds()
             )
 
-            task.logger.debug(f"periodic snapshot is idle and is yielding to stop")
+            task.log.debug(f"periodic snapshot is idle and is yielding to stop")
             return
         except asyncio.TimeoutError:
             # `sleep_for` elapsed.
             state.updated_at = datetime.now(tz=UTC)
 
         count = 0
-        async for doc in fetch_snapshot(task.logger):
+        async for doc in fetch_snapshot(task.log):
             doc.meta_ = BaseDocument.Meta(
                 op="u" if count < state.last_count else "c", row_id=count
             )
             task.captured(binding_index, doc)
             count += 1
 
         digest = task.pending_digest()
-        task.logger.debug(
+        task.log.debug(
             "polled snapshot",
             {
                 "count": count,
@@ -517,12 +517,12 @@ async def _binding_backfill_task(
     )
 
     if state.next_page:
-        task.logger.info(f"resuming backfill", state)
+        task.log.info(f"resuming backfill", state)
     else:
-        task.logger.info(f"beginning backfill", state)
+        task.log.info(f"beginning backfill", state)
 
     while True:
-        page, next_page = await fetch_page(state.next_page, state.cutoff, task.logger)
+        page, next_page = await fetch_page(task.log, state.next_page, state.cutoff)
         for doc in page:
             task.captured(binding_index, doc)
 
@@ -536,7 +536,7 @@ async def _binding_backfill_task(
         bindingStateV1={binding.stateKey: ResourceState(backfill=None)}
     )
     task.checkpoint(connector_state)
-    task.logger.info(f"completed backfill")
+    task.log.info(f"completed backfill")
 
 
 async def _binding_incremental_task(
@@ -549,14 +549,14 @@ async def _binding_incremental_task(
     connector_state = ConnectorState(
         bindingStateV1={binding.stateKey: ResourceState(inc=state)}
     )
-    task.logger.info(f"resuming incremental replication", state)
+    task.log.info(f"resuming incremental replication", state)
 
     while True:
 
         checkpoints = 0
         pending = False
 
-        async for item in fetch_changes(state.cursor, task.logger):
+        async for item in fetch_changes(task.log, state.cursor):
             if isinstance(item, BaseDocument):
                 task.captured(binding_index, item)
                 pending = True
@@ -599,7 +599,7 @@ async def _binding_incremental_task(
                     timeout=binding.resourceConfig.interval.total_seconds(),
                 )
 
-                task.logger.debug(
+                task.log.debug(
                     f"incremental replication is idle and is yielding to stop"
                 )
                 return
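
Under the new `FetchChangesFn` ordering above, an implementation might
look like this sketch (the import path matches the diffed module, while
`my_api_poll` is a hypothetical helper):

```python
from logging import Logger
from typing import AsyncGenerator

from estuary_cdk.capture.common import BaseDocument, LogCursor


async def fetch_changes(
    log: Logger, cursor: LogCursor
) -> AsyncGenerator[BaseDocument | LogCursor, None]:
    log.debug("polling for changes", {"cursor": cursor})
    docs, next_cursor = await my_api_poll(cursor)  # hypothetical helper
    for doc in docs:
        yield doc
    # Yielding a LogCursor checkpoints progress, per the type alias above.
    yield next_cursor
```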