Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Unit Testing #8188

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dbt.task.show import ShowTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask
from dbt.task.unit_test import UnitTestTask


@dataclass
Expand Down Expand Up @@ -845,6 +846,52 @@
return results, success


# dbt test
@cli.command("unit-test")
@click.pass_context
@p.defer
@p.deprecated_defer
@p.exclude
@p.fail_fast
@p.favor_state
@p.deprecated_favor_state
@p.indirect_selection
@p.show_output_format
@p.profile
@p.profiles_dir
@p.project_dir
@p.select
@p.selector
@p.state
@p.defer_state
@p.deprecated_state
@p.store_failures
@p.target
@p.target_path
@p.threads
@p.vars
@p.version_check
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
@requires.unit_test_collection
def unit_test(ctx, **kwargs):
"""Runs tests on data in deployed models. Run this after `dbt run`"""
task = UnitTestTask(

Check warning on line 883 in core/dbt/cli/main.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/main.py#L883

Added line #L883 was not covered by tests
ctx.obj["flags"],
ctx.obj["runtime_config"],
ctx.obj["manifest"],
ctx.obj["unit_test_collection"],
)

results = task.run()
success = task.interpret_results(results)
return results, success

Check warning on line 892 in core/dbt/cli/main.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/main.py#L890-L892

Added lines #L890 - L892 were not covered by tests


# Support running as a module
if __name__ == "__main__":
cli()
23 changes: 23 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dbt.profiler import profiler
from dbt.tracking import active_user, initialize_from_flags, track_run
from dbt.utils import cast_dict_to_dict_of_strings
from dbt.parser.unit_tests import UnitTestManifestLoader
from dbt.plugins import set_up_plugin_manager, get_plugin_manager

from click import Context
Expand Down Expand Up @@ -265,3 +266,25 @@
if len(args0) == 0:
return outer_wrapper
return outer_wrapper(args0[0])


def unit_test_collection(func):
"""A decorator used by click command functions for generating a unit test collection provided a manifest"""

def wrapper(*args, **kwargs):
ctx = args[0]
assert isinstance(ctx, Context)

Check warning on line 276 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L275-L276

Added lines #L275 - L276 were not covered by tests

req_strs = ["manifest", "runtime_config"]
reqs = [ctx.obj.get(req_str) for req_str in req_strs]

Check warning on line 279 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L278-L279

Added lines #L278 - L279 were not covered by tests

if None in reqs:
raise DbtProjectError("manifest and runtime_config required for unit_test_collection")

Check warning on line 282 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L281-L282

Added lines #L281 - L282 were not covered by tests

collection = UnitTestManifestLoader.load(ctx.obj["manifest"], ctx.obj["runtime_config"])

Check warning on line 284 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L284

Added line #L284 was not covered by tests

ctx.obj["unit_test_collection"] = collection

Check warning on line 286 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L286

Added line #L286 was not covered by tests

return func(*args, **kwargs)

Check warning on line 288 in core/dbt/cli/requires.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/cli/requires.py#L288

Added line #L288 was not covered by tests

return update_wrapper(wrapper, func)
20 changes: 20 additions & 0 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,26 @@
return self.call_macro(*args, **kwargs)


class UnitTestMacroGenerator(MacroGenerator):
# this makes UnitTestMacroGenerator objects callable like functions
def __init__(
self,
macro_generator: MacroGenerator,
call_return_value: Any,
) -> None:
super().__init__(

Check warning on line 340 in core/dbt/clients/jinja.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/clients/jinja.py#L340

Added line #L340 was not covered by tests
macro_generator.macro,
macro_generator.context,
macro_generator.node,
macro_generator.stack,
)
self.call_return_value = call_return_value

Check warning on line 346 in core/dbt/clients/jinja.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/clients/jinja.py#L346

Added line #L346 was not covered by tests

def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_return_value

Check warning on line 350 in core/dbt/clients/jinja.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/clients/jinja.py#L349-L350

Added lines #L349 - L350 were not covered by tests


class QueryStringGenerator(BaseMacroGenerator):
def __init__(self, template_str: str, context: Dict[str, Any]) -> None:
super().__init__(context)
Expand Down
13 changes: 10 additions & 3 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from dbt.adapters.factory import get_adapter
from dbt.clients import jinja
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
from dbt.context.providers import (
generate_runtime_model_context,
generate_runtime_unit_test_context,
)
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.nodes import (
ManifestNode,
Expand All @@ -22,6 +25,7 @@
GraphMemberNode,
InjectedCTE,
SeedNode,
UnitTestNode,
)
from dbt.exceptions import (
GraphDependencyNotFoundError,
Expand All @@ -44,6 +48,7 @@
names = {
NodeType.Model: "model",
NodeType.Test: "test",
NodeType.Unit: "unit test",
NodeType.Snapshot: "snapshot",
NodeType.Analysis: "analysis",
NodeType.Macro: "macro",
Expand Down Expand Up @@ -289,8 +294,10 @@
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:

context = generate_runtime_model_context(node, self.config, manifest)
if isinstance(node, UnitTestNode):
context = generate_runtime_unit_test_context(node, self.config, manifest)

Check warning on line 298 in core/dbt/compilation.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/compilation.py#L298

Added line #L298 was not covered by tests
else:
context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)

if isinstance(node, GenericTestNode):
Expand Down
82 changes: 79 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from copy import deepcopy
import os
from typing import (
Callable,
Expand All @@ -17,7 +18,7 @@
from dbt.adapters.base.column import Column
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
from dbt.config import RuntimeConfig, Project
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.context.base import contextmember, contextproperty, Var
Expand All @@ -39,6 +40,7 @@
RefArgs,
AccessType,
SemanticModel,
UnitTestNode,
)
from dbt.contracts.graph.metrics import MetricReference, ResolvedMetricReference
from dbt.contracts.graph.unparsed import NodeVersion
Expand Down Expand Up @@ -566,6 +568,17 @@
return super().create_relation(target_model)


class RuntimeUnitTestRefResolver(RuntimeRefResolver):
def resolve(
self,
target_name: str,
target_package: Optional[str] = None,
target_version: Optional[NodeVersion] = None,
) -> RelationProxy:
target_name = f"{self.model.name}__{target_name}"
return super().resolve(target_name, target_package, target_version)

Check warning on line 579 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L578-L579

Added lines #L578 - L579 were not covered by tests


# `source` implementations
class ParseSourceResolver(BaseSourceResolver):
def resolve(self, source_name: str, table_name: str):
Expand Down Expand Up @@ -670,6 +683,22 @@
pass


class UnitTestVar(RuntimeVar):
def __init__(
self,
context: Dict[str, Any],
config: RuntimeConfig,
node: Resource,
) -> None:
config_copy = None
assert isinstance(node, UnitTestNode)
if node.overrides and node.overrides.vars:
config_copy = deepcopy(config)
config_copy.cli_vars.update(node.overrides.vars)

Check warning on line 697 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L693-L697

Added lines #L693 - L697 were not covered by tests

super().__init__(context, config_copy or config, node=node)

Check warning on line 699 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L699

Added line #L699 was not covered by tests


# Providers
class Provider(Protocol):
execute: bool
Expand Down Expand Up @@ -711,6 +740,16 @@
metric = RuntimeMetricResolver


class RuntimeUnitTestProvider(Provider):
execute = True
Config = RuntimeConfigObject
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeSourceResolver # TODO: RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver


class OperationProvider(RuntimeProvider):
ref = OperationRefResolver

Expand Down Expand Up @@ -1359,7 +1398,7 @@

@contextproperty
def pre_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand All @@ -1368,7 +1407,7 @@

@contextproperty
def post_hooks(self) -> List[Dict[str, Any]]:
if self.model.resource_type in [NodeType.Source, NodeType.Test]:
if self.model.resource_type in [NodeType.Source, NodeType.Test, NodeType.Unit]:
return []
# TODO CT-211
return [
Expand Down Expand Up @@ -1461,6 +1500,25 @@
return None


class UnitTestContext(ModelContext):
model: UnitTestNode

@contextmember
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the overriden unit test environment variable named 'var'.

If there is no unit test override, return the environment variable named 'var'.

If there is no such environment variable set, return the default.

If the default is None, raise an exception for an undefined variable.
"""
if self.model.overrides and var in self.model.overrides.env_vars:
return self.model.overrides.env_vars[var]

Check warning on line 1517 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L1516-L1517

Added lines #L1516 - L1517 were not covered by tests
else:
return super().env_var(var, default)

Check warning on line 1519 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L1519

Added line #L1519 was not covered by tests


# This is called by '_context_for', used in 'render_with_context'
def generate_parser_model_context(
model: ManifestNode,
Expand Down Expand Up @@ -1505,6 +1563,24 @@
return ctx.to_dict()


def generate_runtime_unit_test_context(
unit_test: UnitTestNode,
config: RuntimeConfig,
manifest: Manifest,
) -> Dict[str, Any]:
ctx = UnitTestContext(unit_test, config, manifest, RuntimeUnitTestProvider(), None)
ctx_dict = ctx.to_dict()

Check warning on line 1572 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L1571-L1572

Added lines #L1571 - L1572 were not covered by tests

if unit_test.overrides and unit_test.overrides.macros:
for macro_name, macro_value in unit_test.overrides.macros.items():
context_value = ctx_dict.get(macro_name)
if isinstance(context_value, MacroGenerator):
ctx_dict[macro_name] = UnitTestMacroGenerator(context_value, macro_value)

Check warning on line 1578 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L1574-L1578

Added lines #L1574 - L1578 were not covered by tests
else:
ctx_dict[macro_name] = macro_value
return ctx_dict

Check warning on line 1581 in core/dbt/context/providers.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/providers.py#L1580-L1581

Added lines #L1580 - L1581 were not covered by tests


class ExposureRefResolver(BaseResolver):
def __call__(self, *args, **kwargs) -> str:
package = None
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ def finalize_and_validate(self):
NodeType.Source: SourceConfig,
NodeType.Seed: SeedConfig,
NodeType.Test: TestConfig,
NodeType.Unit: TestConfig,
NodeType.Model: NodeConfig,
NodeType.Snapshot: SnapshotConfig,
}
Expand Down
32 changes: 28 additions & 4 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UnparsedSourceDefinition,
UnparsedSourceTableDefinition,
UnparsedColumn,
UnparsedUnitTestOverrides,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.util import Replaceable, AdditionalPropertiesMixin
Expand Down Expand Up @@ -942,13 +943,27 @@
namespace: Optional[str] = None


@dataclass
class UnitTestMetadata(dbtClassMixin, Replaceable):
# kwargs are the args that are left in the test builder after
# removing configs. They are set from the test builder when
# the test node is created.
kwargs: Dict[str, Any] = field(default_factory=dict)
namespace: Optional[str] = None


# This has to be separated out because it has no default and so
# has to be included as a superclass, not an attribute
@dataclass
class HasTestMetadata(dbtClassMixin):
test_metadata: TestMetadata


@dataclass
class HasUnitTestMetadata(dbtClassMixin):
unit_test_metadata: UnitTestMetadata


@dataclass
class GenericTestNode(TestShouldStoreFailures, CompiledNode, HasTestMetadata):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Test]})
Expand All @@ -970,6 +985,17 @@
return "generic"


@dataclass
class UnitTestNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
attached_node: Optional[str] = None
overrides: Optional[UnparsedUnitTestOverrides] = None

@property
def test_node_type(self):
return "unit"

Check warning on line 996 in core/dbt/contracts/graph/nodes.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/contracts/graph/nodes.py#L996

Added line #L996 was not covered by tests


# ====================================
# Snapshot node
# ====================================
Expand Down Expand Up @@ -1628,6 +1654,7 @@
SqlNode,
GenericTestNode,
SnapshotNode,
UnitTestNode,
]

# All SQL nodes plus SeedNode (csv files)
Expand Down Expand Up @@ -1657,7 +1684,4 @@
Group,
]

TestNode = Union[
SingularTestNode,
GenericTestNode,
]
TestNode = Union[SingularTestNode, GenericTestNode]
Loading
Loading