diff --git a/.changes/unreleased/Breaking Changes-20250117-144053.yaml b/.changes/unreleased/Breaking Changes-20250117-144053.yaml new file mode 100644 index 00000000..460e7f38 --- /dev/null +++ b/.changes/unreleased/Breaking Changes-20250117-144053.yaml @@ -0,0 +1,6 @@ +kind: Breaking Changes +body: Update `fire_event` to handle `warn_or_error` logic +time: 2025-01-17T14:40:53.08567-06:00 +custom: + Author: QMalcolm + Issue: "236" diff --git a/dbt_common/events/event_manager.py b/dbt_common/events/event_manager.py index 507588f3..41c6599c 100644 --- a/dbt_common/events/event_manager.py +++ b/dbt_common/events/event_manager.py @@ -1,19 +1,42 @@ import os import traceback -from typing import List, Optional, Protocol, Tuple +from typing import Any, List, Optional, Protocol, Tuple from dbt_common.events.base_types import BaseEvent, EventLevel, msg_from_base_event, TCallback from dbt_common.events.logger import LoggerConfig, _Logger, _TextLogger, _JsonLogger, LineFormat +from dbt_common.exceptions.events import EventCompilationError +from dbt_common.helper_types import WarnErrorOptions class EventManager: def __init__(self) -> None: self.loggers: List[_Logger] = [] self.callbacks: List[TCallback] = [] - - def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + self.warn_error: bool = False + self.warn_error_options: WarnErrorOptions = WarnErrorOptions(include=[], exclude=[]) + self.require_warn_or_error_handling: bool = False + + def fire_event( + self, + e: BaseEvent, + level: Optional[EventLevel] = None, + node: Any = None, + force_warn_or_error_handling: bool = False, + ) -> None: msg = msg_from_base_event(e, level=level) + if ( + force_warn_or_error_handling or self.require_warn_or_error_handling + ) and msg.info.level == "warn": + event_name = type(e).__name__ + if self.warn_error or self.warn_error_options.includes(event_name): + # This has the potential to create an infinite loop if the handling of the raised + # EventCompilationError fires an event as a warning instead of an error. + raise EventCompilationError(e.message(), node) + elif self.warn_error_options.silenced(event_name): + # Return early if the event is silenced + return + if os.environ.get("DBT_TEST_BINARY_SERIALIZATION"): print(f"--- {msg.info.name}") try: @@ -48,8 +71,17 @@ def flush(self) -> None: class IEventManager(Protocol): callbacks: List[TCallback] loggers: List[_Logger] - - def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + warn_error: bool + warn_error_options: WarnErrorOptions + require_warn_or_error_handling: bool + + def fire_event( + self, + e: BaseEvent, + level: Optional[EventLevel] = None, + node: Any = None, + force_warn_or_error_handling: bool = False, + ) -> None: ... def add_logger(self, config: LoggerConfig) -> None: @@ -66,7 +98,9 @@ def __init__(self) -> None: self.event_history: List[Tuple[BaseEvent, Optional[EventLevel]]] = [] self.loggers = [] - def fire_event(self, e: BaseEvent, level: Optional[EventLevel] = None) -> None: + def fire_event( + self, e: BaseEvent, level: Optional[EventLevel] = None, node: Any = None + ) -> None: self.event_history.append((e, level)) def add_logger(self, config: LoggerConfig) -> None: diff --git a/dbt_common/events/functions.py b/dbt_common/events/functions.py index 86d68237..6641d6b8 100644 --- a/dbt_common/events/functions.py +++ b/dbt_common/events/functions.py @@ -1,9 +1,7 @@ from pathlib import Path from dbt_common.events.event_manager_client import get_event_manager -from dbt_common.exceptions import EventCompilationError from dbt_common.invocation import get_invocation_id -from dbt_common.helper_types import WarnErrorOptions from dbt_common.utils.encoding import ForgivingJSONEncoder from dbt_common.events.base_types import BaseEvent, EventLevel, EventMsg from dbt_common.events.logger import LoggerConfig, LineFormat @@ -13,14 +11,12 @@ import json import os import sys -from typing import Callable, Dict, Optional, TextIO, Union +from typing import Any, Callable, Dict, Optional, TextIO, Union from google.protobuf.json_format import MessageToDict LOG_VERSION = 3 metadata_vars: Optional[Dict[str, str]] = None _METADATA_ENV_PREFIX = "DBT_ENV_CUSTOM_ENV_" -WARN_ERROR_OPTIONS = WarnErrorOptions(include=[], exclude=[]) -WARN_ERROR = False # This global, and the following two functions for capturing stdout logs are # an unpleasant hack we intend to remove as part of API-ification. The GitHub @@ -114,12 +110,9 @@ def msg_to_dict(msg: EventMsg) -> dict: return msg_dict +# This function continues to exist to provide backwards compatibility def warn_or_error(event, node=None) -> None: - event_name = type(event).__name__ - if WARN_ERROR or WARN_ERROR_OPTIONS.includes(event_name): - raise EventCompilationError(event.message(), node) - elif not WARN_ERROR_OPTIONS.silenced(event_name): - fire_event(event) + fire_event(e=event, node=node, force_warn_or_error_handling=True) # an alternative to fire_event which only creates and logs the event value @@ -135,8 +128,15 @@ def fire_event_if( # this is where all the side effects happen branched by event type # (i.e. - mutating the event history, printing to stdout, logging # to files, etc.) -def fire_event(e: BaseEvent, level: Optional[EventLevel] = None) -> None: - get_event_manager().fire_event(e, level=level) +def fire_event( + e: BaseEvent, + level: Optional[EventLevel] = None, + node: Any = None, + force_warn_or_error_handling: bool = False, +) -> None: + get_event_manager().fire_event( + e, level=level, node=node, force_warn_or_error_handling=force_warn_or_error_handling + ) def get_metadata_vars() -> Dict[str, str]: diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 6c4126a1..63566b74 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -3,7 +3,7 @@ from dbt_common.events import functions from dbt_common.events.base_types import EventLevel, WarnLevel from dbt_common.events.event_manager import EventManager -from dbt_common.events.event_manager_client import ctx_set_event_manager +from dbt_common.events.event_manager_client import ctx_set_event_manager, get_event_manager from dbt_common.exceptions import EventCompilationError from dbt_common.helper_types import WarnErrorOptions from tests.unit.utils import EventCatcher @@ -38,9 +38,44 @@ def valid_error_names() -> Set[str]: return {Note.__name__} -class TestWarnOrError: +class TestFireEvent: + @pytest.mark.parametrize( + "force_warn_or_error_handling,require_warn_or_error_handling,should_raise", + [ + (True, True, True), + (True, False, True), + (False, True, True), + (False, False, False), + ], + ) + def test_warning_handling( + self, + set_event_manager_with_catcher: None, + force_warn_or_error_handling: bool, + require_warn_or_error_handling: bool, + should_raise: bool, + ) -> None: + manager = get_event_manager() + manager.warn_error = True + manager.require_warn_or_error_handling = require_warn_or_error_handling + try: + functions.fire_event( + e=Note(msg="hi"), force_warn_or_error_handling=force_warn_or_error_handling + ) + except EventCompilationError: + assert ( + should_raise + ), "`fire_event` raised an error from a warning when it shouldn't have" + return + + assert ( + not should_raise + ), "`fire_event` didn't raise an error from a warning when it should have" + + +class TestDeprecatedWarnOrError: def test_fires_error(self, valid_error_names: Set[str]) -> None: - functions.WARN_ERROR_OPTIONS = WarnErrorOptions( + get_event_manager().warn_error_options = WarnErrorOptions( include="*", valid_error_names=valid_error_names ) with pytest.raises(EventCompilationError): @@ -52,7 +87,7 @@ def test_fires_warning( event_catcher: EventCatcher, set_event_manager_with_catcher: None, ) -> None: - functions.WARN_ERROR_OPTIONS = WarnErrorOptions( + get_event_manager().warn_error_options = WarnErrorOptions( include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names ) functions.warn_or_error(Note(msg="hi")) @@ -65,7 +100,7 @@ def test_silenced( event_catcher: EventCatcher, set_event_manager_with_catcher: None, ) -> None: - functions.WARN_ERROR_OPTIONS = WarnErrorOptions( + get_event_manager().warn_error_options = WarnErrorOptions( include="*", silence=list(valid_error_names), valid_error_names=valid_error_names ) functions.warn_or_error(Note(msg="hi"))