Skip to content

Commit

Permalink
Fixes #7785: fail-fast behavior (#8066)
Browse files: browse the repository at this point in the history
  • Loading branch information
aranke authored Jul 11, 2023
1 parent fd233ea commit 07c3dcd
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 72 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230710-172547.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix fail-fast behavior (including retry)
time: 2023-07-10T17:25:47.912129-05:00
custom:
Author: aranke
Issue: "7785"
16 changes: 16 additions & 0 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading

from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.nodes import SourceDefinition, ResultNode
from dbt.contracts.util import (
Expand Down Expand Up @@ -161,6 +163,20 @@ class RunResult(NodeResult):
def skipped(self):
return self.status == RunStatus.Skipped

@classmethod
def from_node(cls, node: ResultNode, status: RunStatus, message: Optional[str]):
    """Build a ``RunResult`` for ``node`` that carries no execution data.

    The result has the given ``status`` and ``message``, is attributed to the
    name of the thread calling this method, and reports zero execution time,
    no timing entries, an empty adapter response and no failure count.
    """
    result_fields = dict(
        node=node,
        status=status,
        message=message,
        # Attribute the result to whichever thread is calling us.
        thread_id=threading.current_thread().name,
        execution_time=0,
        timing=[],
        adapter_response={},
        failures=None,
    )
    return RunResult(**result_fields)


@dataclass
class ExecutionResult(dbtClassMixin):
Expand Down
15 changes: 1 addition & 14 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,6 @@ def from_run_result(self, result, start_time, timing_info):
failures=result.failures,
)

def skip_result(self, node, message):
    """Return a ``RunResult`` marking ``node`` as skipped.

    The result has status ``RunStatus.Skipped`` and the given ``message``,
    is attributed to the name of the calling thread, and carries no timing,
    adapter response or failure information.
    """
    return RunResult(
        node=node,
        message=message,
        status=RunStatus.Skipped,
        # Name of the thread that produced this result.
        thread_id=threading.current_thread().name,
        execution_time=0,
        timing=[],
        adapter_response={},
        failures=None,
    )

def compile_and_execute(self, manifest, ctx):
result = None
with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext():
Expand Down Expand Up @@ -483,7 +470,7 @@ def on_skip(self):
)
)

node_result = self.skip_result(self.node, error_message)
node_result = RunResult.from_node(self.node, RunStatus.Skipped, error_message)
return node_result

def do_skip(self, cause=None):
Expand Down
75 changes: 42 additions & 33 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import os
import time
from pathlib import Path
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from pathlib import Path
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet

from .printer import (
print_run_result_error,
print_run_end_messages,
)

from dbt.task.base import ConfiguredTask
import dbt.exceptions
import dbt.tracking
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import (
NodeStatus,
RunExecutionResult,
RunningStatus,
RunResult,
RunStatus,
)
from dbt.contracts.state import PreviousState
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
Formatting,
Expand All @@ -36,25 +35,29 @@
EndRunResult,
NothingToDo,
)
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import (
DbtInternalError,
NotImplementedError,
DbtRuntimeError,
FailFastError,
)

from dbt.flags import get_flags
from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
)
from dbt.parser.manifest import write_manifest
import dbt.tracking

import dbt.exceptions
from dbt.flags import get_flags
import dbt.utils
from dbt.contracts.graph.manifest import WritableManifest
from dbt.task.base import ConfiguredTask
from .printer import (
print_run_result_error,
print_run_end_messages,
)

RESULT_FILE_NAME = "run_results.json"
RUNNING_STATE = DbtProcessState("running")
Expand Down Expand Up @@ -360,21 +363,27 @@ def execute_nodes(self):
pool = ThreadPool(num_threads)
try:
self.run_queue(pool)

except FailFastError as failure:
self._cancel_connections(pool)

executed_node_ids = [r.node.unique_id for r in self.node_results]

for r in self._flattened_nodes:
if r.unique_id not in executed_node_ids:
self.node_results.append(
RunResult.from_node(r, RunStatus.Skipped, "Skipping due to fail_fast")
)

print_run_result_error(failure.result)
raise

except KeyboardInterrupt:
self._cancel_connections(pool)
print_run_end_messages(self.node_results, keyboard_interrupt=True)
raise

pool.close()
pool.join()

return self.node_results
finally:
pool.close()
pool.join()
return self.node_results

def _mark_dependent_errors(self, node_id, result, cause):
if self.graph is None:
Expand Down
6 changes: 4 additions & 2 deletions tests/functional/dependencies/test_local_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dbt.semver
import dbt.config
import dbt.exceptions
from dbt.contracts.results import RunStatus

from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture

Expand Down Expand Up @@ -207,8 +208,9 @@ def models(self):

def test_missing_dependency(self, project):
# dbt should raise a runtime exception
with pytest.raises(dbt.exceptions.DbtRuntimeError):
run_dbt(["compile"])
res = run_dbt(["compile"], expect_pass=False)
assert len(res) == 1
assert res[0].status == RunStatus.Error


class TestSimpleDependencyWithSchema(BaseDependencyTest):
Expand Down
17 changes: 10 additions & 7 deletions tests/functional/fail_fast/test_fail_fast_run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest
import json
from pathlib import Path

import pytest

from dbt.contracts.results import RunResult
from dbt.tests.util import run_dbt


models__one_sql = """
select 1
"""
Expand All @@ -30,8 +28,11 @@ def test_fail_fast_run(
models, # noqa: F811
):
res = run_dbt(["run", "--fail-fast", "--threads", "1"], expect_pass=False)
# a RunResult contains only one node so we can be sure only one model was run
assert type(res) == RunResult
assert {r.node.unique_id: r.status for r in res.results} == {
"model.test.one": "success",
"model.test.two": "error",
}

run_results_file = Path(project.project_root) / "target/run_results.json"
assert run_results_file.is_file()
with run_results_file.open() as run_results_str:
Expand All @@ -57,5 +58,7 @@ def test_fail_fast_run_user_config(
models, # noqa: F811
):
res = run_dbt(["run", "--threads", "1"], expect_pass=False)
# a RunResult contains only one node so we can be sure only one model was run
assert type(res) == RunResult
assert {r.node.unique_id: r.status for r in res.results} == {
"model.test.one": "success",
"model.test.two": "error",
}
70 changes: 54 additions & 16 deletions tests/functional/retry/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,6 @@ def test_run_operation(self, project):
results = run_dbt(["retry"], expect_pass=False)
assert {n.unique_id: n.status for n in results.results} == expected_statuses

def test_fail_fast(self, project):
    """Check that a fail-fast build and the retries after it report ``sample_model`` as errored."""
    # Initial build with --fail-fast: the returned object is a single result.
    build_failure = run_dbt(["--warn-error", "build", "--fail-fast"], expect_pass=False)
    assert build_failure.status == RunStatus.Error
    assert build_failure.node.name == "sample_model"

    # A plain retry returns a results collection holding exactly one entry.
    retry_run = run_dbt(["retry"], expect_pass=False)
    assert len(retry_run.results) == 1
    sole_result = retry_run.results[0]
    assert sole_result.status == RunStatus.Error
    assert sole_result.node.name == "sample_model"

    # Retrying with --fail-fast again yields a single errored result.
    retry_failure = run_dbt(["retry", "--fail-fast"], expect_pass=False)
    assert retry_failure.status == RunStatus.Error
    assert retry_failure.node.name == "sample_model"

def test_removed_file(self, project):
run_dbt(["build"], expect_pass=False)

Expand All @@ -180,3 +164,57 @@ def test_removed_file_leaf_node(self, project):
rm_file("models", "third_model.sql")
with pytest.raises(ValueError, match="Couldn't find model 'model.test.third_model'"):
run_dbt(["retry"], expect_pass=False)


class TestFailFast:
    """Functional test of ``--fail-fast`` on ``build`` and the ``retry`` runs that follow it."""

    @pytest.fixture(scope="class")
    def models(self):
        """Provide the four model files used by the test project."""
        model_files = {
            "sample_model.sql": models__sample_model,
            "second_model.sql": models__second_model,
            "union_model.sql": models__union_model,
            # Written with a trailing semicolon; the test later rewrites this
            # model without it, after which it is expected to succeed.
            "final_model.sql": "select * from {{ ref('union_model') }};",
        }
        return model_files

    def test_fail_fast(self, project):
        """Walk through build -> retry -> fix -> retry until nothing is left to retry."""

        def status_by_node(run):
            # Map each result's node unique_id to its status.
            return {entry.node.unique_id: entry.status for entry in run.results}

        build_run = run_dbt(["--fail-fast", "build"], expect_pass=False)
        assert status_by_node(build_run) == {
            "model.test.sample_model": RunStatus.Error,
            "model.test.second_model": RunStatus.Success,
            "model.test.union_model": RunStatus.Skipped,
            "model.test.final_model": RunStatus.Skipped,
        }

        # Check that retry inherits fail-fast from upstream command (build)
        first_retry = run_dbt(["retry"], expect_pass=False)
        assert status_by_node(first_retry) == {
            "model.test.sample_model": RunStatus.Error,
            "model.test.union_model": RunStatus.Skipped,
            "model.test.final_model": RunStatus.Skipped,
        }

        # Replace sample_model with SQL that is expected to run cleanly.
        sample_model_fix = "select 1 as id, 1 as foo"
        write_file(sample_model_fix, "models", "sample_model.sql")

        second_retry = run_dbt(["retry"], expect_pass=False)
        assert status_by_node(second_retry) == {
            "model.test.sample_model": RunStatus.Success,
            "model.test.union_model": RunStatus.Success,
            "model.test.final_model": RunStatus.Error,
        }

        # Only the still-failing final_model is retried now.
        third_retry = run_dbt(["retry"], expect_pass=False)
        assert status_by_node(third_retry) == {
            "model.test.final_model": RunStatus.Error,
        }

        # Rewrite final_model without the trailing semicolon.
        final_model_fix = "select * from {{ ref('union_model') }}"
        write_file(final_model_fix, "models", "final_model.sql")

        fourth_retry = run_dbt(["retry"])
        assert status_by_node(fourth_retry) == {
            "model.test.final_model": RunStatus.Success,
        }

        # Everything has succeeded, so a further retry has no results.
        last_retry = run_dbt(["retry"])
        assert status_by_node(last_retry) == {}

0 comments on commit 07c3dcd

Please sign in to comment.