diff --git a/.changes/unreleased/Fixes-20230710-172547.yaml b/.changes/unreleased/Fixes-20230710-172547.yaml new file mode 100644 index 00000000000..f947cc7897d --- /dev/null +++ b/.changes/unreleased/Fixes-20230710-172547.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix fail-fast behavior (including retry) +time: 2023-07-10T17:25:47.912129-05:00 +custom: + Author: aranke + Issue: "7785" diff --git a/core/dbt/contracts/results.py b/core/dbt/contracts/results.py index fea3bb30e28..aaa036e6a74 100644 --- a/core/dbt/contracts/results.py +++ b/core/dbt/contracts/results.py @@ -1,3 +1,5 @@ +import threading + from dbt.contracts.graph.unparsed import FreshnessThreshold from dbt.contracts.graph.nodes import SourceDefinition, ResultNode from dbt.contracts.util import ( @@ -161,6 +163,20 @@ class RunResult(NodeResult): def skipped(self): return self.status == RunStatus.Skipped + @classmethod + def from_node(cls, node: ResultNode, status: RunStatus, message: Optional[str]): + thread_id = threading.current_thread().name + return RunResult( + status=status, + thread_id=thread_id, + execution_time=0, + timing=[], + message=message, + node=node, + adapter_response={}, + failures=None, + ) + @dataclass class ExecutionResult(dbtClassMixin): diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index a7ec1e046db..0aae0bd8851 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -296,19 +296,6 @@ def from_run_result(self, result, start_time, timing_info): failures=result.failures, ) - def skip_result(self, node, message): - thread_id = threading.current_thread().name - return RunResult( - status=RunStatus.Skipped, - thread_id=thread_id, - execution_time=0, - timing=[], - message=message, - node=node, - adapter_response={}, - failures=None, - ) - def compile_and_execute(self, manifest, ctx): result = None with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext(): @@ -483,7 +470,7 @@ def on_skip(self): ) ) - node_result = self.skip_result(self.node, error_message) + node_result = RunResult.from_node(self.node, RunStatus.Skipped, error_message) return node_result def do_skip(self, cause=None): diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 5c51a11d31c..fb6f185d355 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -1,29 +1,28 @@ import os import time -from pathlib import Path from abc import abstractmethod from concurrent.futures import as_completed from datetime import datetime from multiprocessing.dummy import Pool as ThreadPool +from pathlib import Path from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet -from .printer import ( - print_run_result_error, - print_run_end_messages, -) - -from dbt.task.base import ConfiguredTask +import dbt.exceptions +import dbt.tracking +import dbt.utils from dbt.adapters.base import BaseRelation from dbt.adapters.factory import get_adapter -from dbt.logger import ( - DbtProcessState, - TextOnly, - UniqueID, - TimestampNamed, - DbtModelState, - ModelMetadata, - NodeCount, +from dbt.contracts.graph.manifest import WritableManifest +from dbt.contracts.graph.nodes import ResultNode +from dbt.contracts.results import ( + NodeStatus, + RunExecutionResult, + RunningStatus, + RunResult, + RunStatus, ) +from dbt.contracts.state import PreviousState +from dbt.events.contextvars import log_contextvars, task_contextvars from dbt.events.functions import fire_event, warn_or_error from dbt.events.types import ( Formatting, @@ -36,25 +35,29 @@ EndRunResult, NothingToDo, ) -from dbt.events.contextvars import log_contextvars, task_contextvars -from dbt.contracts.graph.nodes import ResultNode -from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus -from dbt.contracts.state import PreviousState from dbt.exceptions import ( DbtInternalError, NotImplementedError, DbtRuntimeError, FailFastError, ) - +from dbt.flags import get_flags from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference +from dbt.logger import ( + DbtProcessState, + TextOnly, + UniqueID, + TimestampNamed, + DbtModelState, + ModelMetadata, + NodeCount, +) from dbt.parser.manifest import write_manifest -import dbt.tracking - -import dbt.exceptions -from dbt.flags import get_flags -import dbt.utils -from dbt.contracts.graph.manifest import WritableManifest +from dbt.task.base import ConfiguredTask +from .printer import ( + print_run_result_error, + print_run_end_messages, +) RESULT_FILE_NAME = "run_results.json" RUNNING_STATE = DbtProcessState("running") @@ -360,21 +363,27 @@ def execute_nodes(self): pool = ThreadPool(num_threads) try: self.run_queue(pool) - except FailFastError as failure: self._cancel_connections(pool) + + executed_node_ids = [r.node.unique_id for r in self.node_results] + + for r in self._flattened_nodes: + if r.unique_id not in executed_node_ids: + self.node_results.append( + RunResult.from_node(r, RunStatus.Skipped, "Skipping due to fail_fast") + ) + print_run_result_error(failure.result) raise - except KeyboardInterrupt: self._cancel_connections(pool) print_run_end_messages(self.node_results, keyboard_interrupt=True) raise - - pool.close() - pool.join() - - return self.node_results + finally: + pool.close() + pool.join() + return self.node_results def _mark_dependent_errors(self, node_id, result, cause): if self.graph is None: diff --git a/tests/functional/dependencies/test_local_dependency.py b/tests/functional/dependencies/test_local_dependency.py index c7a9f01cc0a..c93d4e2713b 100644 --- a/tests/functional/dependencies/test_local_dependency.py +++ b/tests/functional/dependencies/test_local_dependency.py @@ -13,6 +13,7 @@ import dbt.semver import dbt.config import dbt.exceptions +from dbt.contracts.results import RunStatus from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture @@ -207,8 +208,9 @@ def models(self): def test_missing_dependency(self, project): # dbt should raise a runtime exception - with pytest.raises(dbt.exceptions.DbtRuntimeError): - run_dbt(["compile"]) + res = run_dbt(["compile"], expect_pass=False) + assert len(res) == 1 + assert res[0].status == RunStatus.Error class TestSimpleDependencyWithSchema(BaseDependencyTest): diff --git a/tests/functional/fail_fast/test_fail_fast_run.py b/tests/functional/fail_fast/test_fail_fast_run.py index ad0a84e169d..ea956a2d540 100644 --- a/tests/functional/fail_fast/test_fail_fast_run.py +++ b/tests/functional/fail_fast/test_fail_fast_run.py @@ -1,12 +1,10 @@ -import pytest import json from pathlib import Path +import pytest -from dbt.contracts.results import RunResult from dbt.tests.util import run_dbt - models__one_sql = """ select 1 """ @@ -30,8 +28,11 @@ def test_fail_fast_run( models, # noqa: F811 ): res = run_dbt(["run", "--fail-fast", "--threads", "1"], expect_pass=False) - # a RunResult contains only one node so we can be sure only one model was run - assert type(res) == RunResult + assert {r.node.unique_id: r.status for r in res.results} == { + "model.test.one": "success", + "model.test.two": "error", + } + run_results_file = Path(project.project_root) / "target/run_results.json" assert run_results_file.is_file() with run_results_file.open() as run_results_str: @@ -57,5 +58,7 @@ def test_fail_fast_run_user_config( models, # noqa: F811 ): res = run_dbt(["run", "--threads", "1"], expect_pass=False) - # a RunResult contains only one node so we can be sure only one model was run - assert type(res) == RunResult + assert {r.node.unique_id: r.status for r in res.results} == { + "model.test.one": "success", + "model.test.two": "error", + } diff --git a/tests/functional/retry/test_retry.py b/tests/functional/retry/test_retry.py index 238cd6fdf07..56197194046 100644 --- a/tests/functional/retry/test_retry.py +++ b/tests/functional/retry/test_retry.py @@ -145,22 +145,6 @@ def test_run_operation(self, project): results = run_dbt(["retry"], expect_pass=False) assert {n.unique_id: n.status for n in results.results} == expected_statuses - def test_fail_fast(self, project): - result = run_dbt(["--warn-error", "build", "--fail-fast"], expect_pass=False) - - assert result.status == RunStatus.Error - assert result.node.name == "sample_model" - - results = run_dbt(["retry"], expect_pass=False) - - assert len(results.results) == 1 - assert results.results[0].status == RunStatus.Error - assert results.results[0].node.name == "sample_model" - - result = run_dbt(["retry", "--fail-fast"], expect_pass=False) - assert result.status == RunStatus.Error - assert result.node.name == "sample_model" - def test_removed_file(self, project): run_dbt(["build"], expect_pass=False) @@ -180,3 +164,57 @@ def test_removed_file_leaf_node(self, project): rm_file("models", "third_model.sql") with pytest.raises(ValueError, match="Couldn't find model 'model.test.third_model'"): run_dbt(["retry"], expect_pass=False) + + +class TestFailFast: + @pytest.fixture(scope="class") + def models(self): + return { + "sample_model.sql": models__sample_model, + "second_model.sql": models__second_model, + "union_model.sql": models__union_model, + "final_model.sql": "select * from {{ ref('union_model') }};", + } + + def test_fail_fast(self, project): + results = run_dbt(["--fail-fast", "build"], expect_pass=False) + assert {r.node.unique_id: r.status for r in results.results} == { + "model.test.sample_model": RunStatus.Error, + "model.test.second_model": RunStatus.Success, + "model.test.union_model": RunStatus.Skipped, + "model.test.final_model": RunStatus.Skipped, + } + + # Check that retry inherits fail-fast from upstream command (build) + results = run_dbt(["retry"], expect_pass=False) + assert {r.node.unique_id: r.status for r in results.results} == { + "model.test.sample_model": RunStatus.Error, + "model.test.union_model": RunStatus.Skipped, + "model.test.final_model": RunStatus.Skipped, + } + + fixed_sql = "select 1 as id, 1 as foo" + write_file(fixed_sql, "models", "sample_model.sql") + + results = run_dbt(["retry"], expect_pass=False) + assert {r.node.unique_id: r.status for r in results.results} == { + "model.test.sample_model": RunStatus.Success, + "model.test.union_model": RunStatus.Success, + "model.test.final_model": RunStatus.Error, + } + + results = run_dbt(["retry"], expect_pass=False) + assert {r.node.unique_id: r.status for r in results.results} == { + "model.test.final_model": RunStatus.Error, + } + + fixed_sql = "select * from {{ ref('union_model') }}" + write_file(fixed_sql, "models", "final_model.sql") + + results = run_dbt(["retry"]) + assert {r.node.unique_id: r.status for r in results.results} == { + "model.test.final_model": RunStatus.Success, + } + + results = run_dbt(["retry"]) + assert {r.node.unique_id: r.status for r in results.results} == {}