Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paw/record sequence no #251

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20250226-150844.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add sequence number to record/replay records and add new invocation context accessor
time: 2025-02-26T15:08:44.584623-05:00
custom:
Author: peterallenwebb
Issue: "251"
13 changes: 13 additions & 0 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def __init__(self, env: Mapping[str, str]):
self._env_secrets: Optional[List[str]] = None
self._env_private = env_private
self.recorder: Optional[Recorder] = None

# If set to True later, this flag will prevent dbt from creating a new
# invocation context for every invocation, which is useful for testing
# scenarios.
self.do_not_reset = False

# This class will also eventually manage the invocation_id, flags, event manager, etc.

@property
Expand Down Expand Up @@ -85,3 +91,10 @@ def get_invocation_context() -> InvocationContext:
invocation_var = reliably_get_invocation_var()
ctx = invocation_var.get()
return ctx


def try_get_invocation_context() -> Optional[InvocationContext]:
    """Return the current invocation context, or None when it is unavailable.

    This is a non-raising variant of get_invocation_context(): any exception
    raised while looking up the context is swallowed and reported as None.
    """
    try:
        ctx = get_invocation_context()
    except Exception:
        # Best-effort accessor: callers only need to know whether a context
        # could be obtained, not why the lookup failed.
        return None
    return ctx
45 changes: 34 additions & 11 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import inspect
import json
import os
import threading

from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, TextIO, Tuple, Type
Expand All @@ -31,9 +32,10 @@ class Record:
result_cls: Optional[type] = None
group: Optional[str] = None

def __init__(self, params, result) -> None:
def __init__(self, params, result, seq=None) -> None:
self.params = params
self.result = result
self.seq = seq

def to_dict(self) -> Dict[str, Any]:
return {
Expand All @@ -45,6 +47,7 @@ def to_dict(self) -> Dict[str, Any]:
else dataclasses.asdict(self.result)
if self.result is not None
else None,
"seq": self.seq,
}

@classmethod
Expand All @@ -61,7 +64,8 @@ def from_dict(cls, dct: Mapping) -> "Record":
if cls.result_cls is not None
else None
)
return cls(params=p, result=r)
s = dct.get("seq", None)
return cls(params=p, result=r, seq=s)


class Diff:
Expand Down Expand Up @@ -167,6 +171,9 @@ def __init__(
if self.mode == RecorderMode.REPLAY:
self._unprocessed_records_by_type = self.load(self.previous_recording_path)

self._counter = 0
self._counter_lock = threading.Lock()

@classmethod
def register_record_type(cls, rec_type) -> Any:
cls._record_cls_by_name[rec_type.__name__] = rec_type
Expand All @@ -177,6 +184,11 @@ def add_record(self, record: Record) -> None:
rec_cls_name = record.__class__.__name__ # type: ignore
if rec_cls_name not in self._records_by_type:
self._records_by_type[rec_cls_name] = []

with self._counter_lock:
record.seq = self._counter
self._counter += 1

self._records_by_type[rec_cls_name].append(record)

def pop_matching_record(self, params: Any) -> Optional[Record]:
Expand All @@ -199,21 +211,27 @@ def pop_matching_record(self, params: Any) -> Optional[Record]:
return match

def write_json(self, out_stream: TextIO):
d = self._to_dict()
d = self._to_list()
json.dump(d, out_stream)

def write(self) -> None:
    """Write the recording as JSON to the file at self.current_recording_path.

    The file is opened in text write mode, so an existing file at that path
    is overwritten. Serialization itself is delegated to write_json().
    """
    with open(self.current_recording_path, "w") as out_file:
        self.write_json(out_file)

def _to_dict(self) -> Dict:
dct: Dict[str, Any] = {}
def _to_list(self) -> List[Dict]:

def get_tagged_dict(record: Record, record_type: str) -> Dict :
d = record.to_dict()
d["type"] = record_type
return d

record_list: List[Dict] = []
for record_type in self._records_by_type:
record_list = [r.to_dict() for r in self._records_by_type[record_type]]
dct[record_type] = record_list
record_list.extend(get_tagged_dict(r, record_type) for r in self._records_by_type[record_type])

record_list.sort(key=lambda r: r["seq"])

return dct
return record_list

@classmethod
def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]:
Expand Down Expand Up @@ -458,9 +476,14 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
param_args = args[1:] if method else args
if method and id_field_name is not None:
if index_on_thread_id:
from dbt_common.context import get_invocation_context

param_args = (get_invocation_context().name,) + param_args
from dbt_common.events.contextvars import get_node_info
node_info = get_node_info()
if node_info and "unique_id" in node_info:
thread_name = node_info["unique_id"]
else:
from dbt_common.context import get_invocation_context
thread_name = get_invocation_context().name
param_args = (thread_name,) + param_args
else:
param_args = (getattr(args[0], id_field_name),) + param_args

Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,5 +334,12 @@ def test_func(cls, a: int) -> int:

assert recorder._records_by_type["TestAutoRecord"][0].params.a == 1
assert recorder._records_by_type["TestAutoRecord"][0].result.return_val == 3
a = recorder._records_by_type["TestAutoRecord"][0].seq
assert recorder._records_by_type["TestAutoRecord"][1].params.a == 2
assert recorder._records_by_type["TestAutoRecord"][1].result.return_val == 6
assert recorder._records_by_type["TestAutoRecord"][1].seq == a + 1

stream = StringIO()
recorder.write_json(stream)
stream.getvalue()
pass
Loading