Skip to content

Commit

Permalink
Refine support for classmethods in record/replay (#243)
Browse files Browse the repository at this point in the history
* Refine support for classmethods in record/replay

* Add changelog entry.

* Remove gate which did not work correctly.
  • Loading branch information
peterallenwebb authored Feb 5, 2025
1 parent 3956ae7 commit 397fb3f
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 25 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20250204-160452.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Refine support for classmethods in record/replay
time: 2025-02-04T16:04:52.690444-05:00
custom:
Author: peterallenwebb
Issue: "243"
1 change: 1 addition & 0 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, env: Mapping[str, str]):
else:
self._env = env_public

self.name = "unset"
self._env_secrets: Optional[List[str]] = None
self._env_private = env_private
self.recorder: Optional[Recorder] = None
Expand Down
77 changes: 54 additions & 23 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import inspect
import json
import os
import threading

from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, TextIO, Tuple, Type
Expand Down Expand Up @@ -325,7 +324,14 @@ def auto_record_function(
needed. That makes it suitable for quickly adding record support to simple
functions with simple parameters."""
return functools.partial(
_record_function_inner, record_name, method, False, None, group, index_on_thread_name
_record_function_inner,
record_name,
method,
False,
None,
group,
index_on_thread_name,
False,
)


Expand All @@ -339,7 +345,14 @@ def record_function(
have their function calls recorded during record mode, and mocked out with
previously recorded replay data during replay."""
return functools.partial(
_record_function_inner, record_type, method, tuple_result, id_field_name, None, False
_record_function_inner,
record_type,
method,
tuple_result,
id_field_name,
None,
False,
False,
)


Expand Down Expand Up @@ -383,12 +396,15 @@ def _from_dict(self, data):


def _record_function_inner(
record_type, method, tuple_result, id_field_name, group, index_on_thread_id, func_to_record
record_type,
method,
tuple_result,
id_field_name,
group,
index_on_thread_id,
is_classmethod,
func_to_record,
):
# When record/replay is not active, do nothing.
if get_record_mode_from_env() is None:
return func_to_record

if isinstance(record_type, str):
return_type = inspect.signature(func_to_record).return_annotation
fields = _get_arg_fields(inspect.getfullargspec(func_to_record), method)
Expand Down Expand Up @@ -426,21 +442,25 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
except LookupError:
pass

call_args = args[1:] if is_classmethod else args

if recorder is None:
return func_to_record(*args, **kwargs)
return func_to_record(*call_args, **kwargs)

if recorder.recorded_types is not None and not (
record_type.__name__ in recorder.recorded_types
or record_type.group in recorder.recorded_types
):
return func_to_record(*args, **kwargs)
return func_to_record(*call_args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args
if method and id_field_name is not None:
if index_on_thread_id:
param_args = (threading.current_thread().name,) + param_args
from dbt_common.context import get_invocation_context

param_args = (get_invocation_context().name,) + param_args
else:
param_args = (getattr(args[0], id_field_name),) + param_args

Expand All @@ -451,15 +471,15 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
include = params._include()

if not include:
return func_to_record(*args, **kwargs)
return func_to_record(*call_args, **kwargs)

if recorder.mode == RecorderMode.REPLAY:
return recorder.expect_record(params)
if RECORDED_BY_HIGHER_FUNCTION.get():
return func_to_record(*args, **kwargs)
return func_to_record(*call_args, **kwargs)

RECORDED_BY_HIGHER_FUNCTION.set(True)
r = func_to_record(*args, **kwargs)
r = func_to_record(*call_args, **kwargs)
result = (
None
if record_type.result_cls is None
Expand Down Expand Up @@ -487,6 +507,11 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
return record_replay_wrapper


def _is_classmethod(method):
b = inspect.ismethod(method) and isinstance(method.__self__, type)
return b


def supports_replay(cls):
"""Class decorator which adds record/replay support for a class. In particular,
this decorator ensures that calls to overriden functions are still recorded."""
Expand All @@ -507,19 +532,25 @@ def wrapping_init_subclass(sub_cls):
metadata = getattr(method, "_record_metadata", None)
if method and getattr(method, "_record_metadata", None):
sub_method = getattr(sub_cls, method_name, None)
recorded_sub_method = _record_function_inner(
metadata["record_type"],
metadata["method"],
metadata["tuple_result"],
metadata["id_field_name"],
metadata["group"],
metadata["index_on_thread_id"],
_is_classmethod(method),
sub_method,
)

if _is_classmethod(method):
recorded_sub_method = classmethod(recorded_sub_method)

if sub_method is not None:
setattr(
sub_cls,
method_name,
_record_function_inner(
metadata["record_type"],
metadata["method"],
metadata["tuple_result"],
metadata["id_field_name"],
metadata["group"],
metadata["index_on_thread_id"],
sub_method,
),
recorded_sub_method,
)

original_init_subclass()
Expand Down
56 changes: 54 additions & 2 deletions tests/unit/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,16 @@ class RecordableSubclass(Recordable):
def test_func(self, a: int) -> int:
return 3 * a

rs = RecordableSubclass()
class RecordableSubSubclass(RecordableSubclass):
def test_func(self, a: int) -> int:
return 4 * a

rs = RecordableSubSubclass()

rs.test_func(1)

assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 1
assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == 3
assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == 4


class CustomType:
Expand Down Expand Up @@ -284,3 +288,51 @@ def test_func(a: CustomType) -> CustomType:
recorder.write_json(buffer)
buffer.seek(0)
recorder.load_json(buffer)


def test_record_classmethod() -> None:
os.environ["DBT_RECORDER_MODE"] = "Record"
recorder = Recorder(RecorderMode.RECORD, None)
set_invocation_context({})
get_invocation_context().recorder = recorder

@supports_replay
class Recordable:
@classmethod
@auto_record_function("TestAuto")
def test_func(cls, a: int) -> int:
return 2 * a

Recordable.test_func(1)

assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 1
assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == 2


def test_record_classmethod_override() -> None:
os.environ["DBT_RECORDER_MODE"] = "Record"
recorder = Recorder(RecorderMode.RECORD, None)
set_invocation_context({})
get_invocation_context().recorder = recorder

@supports_replay
class Recordable:
@classmethod
@auto_record_function("TestAuto")
def test_func(cls, a: int) -> int:
return 2 * a

class RecordableSubclass(Recordable):
@classmethod
def test_func(cls, a: int) -> int:
return 3 * a

RecordableSubclass.test_func(1)

rs = RecordableSubclass()
rs.test_func(2)

assert recorder._records_by_type["TestAutoRecord"][0].params.a == 1
assert recorder._records_by_type["TestAutoRecord"][0].result.return_val == 3
assert recorder._records_by_type["TestAutoRecord"][1].params.a == 2
assert recorder._records_by_type["TestAutoRecord"][1].result.return_val == 6

0 comments on commit 397fb3f

Please sign in to comment.