feature: kedro-airflow DAG kwarg configuration
Signed-off-by: Simon Brugman <[email protected]>
sbrugman committed Jul 12, 2023
1 parent 884c1fe commit c37ffbf
Showing 7 changed files with 415 additions and 69 deletions.
55 changes: 55 additions & 0 deletions kedro-airflow/README.md
@@ -55,3 +55,58 @@ You can use the additional command line argument `--jinja-file` (alias `-j`) to
```bash
kedro airflow create --jinja-file=./custom/template.j2
```

#### How can I pass arguments to the Airflow DAGs dynamically?

`kedro-airflow` picks up configuration from `airflow.yml` files. Arguments can be specified globally or per pipeline:

```yaml
# Global parameters
default:
    start_date: [2023, 1, 1]
    max_active_runs: 3
    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
    schedule_interval: "@once"
    catchup: false
    # Default settings applied to all tasks
    owner: "airflow"
    depends_on_past: false
    email_on_failure: false
    email_on_retry: false
    retries: 1
    retry_delay: 5

# Arguments specific to the pipeline (overrides the parameters above)
data_science:
    owner: "airflow-ds"
```
Arguments can also be passed via `--params` on the command line:

```bash
kedro airflow create --params "schedule_interval='@weekly'"
```

These variables are passed to the Jinja2 template.
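
For illustration, with the configuration above the `data_science` pipeline would produce a DAG roughly like the following. This is a sketch based on the default template; the actual `dag_id` and the generated tasks depend on your project:

```python
from datetime import datetime, timedelta

from airflow import DAG

with DAG(
    dag_id="data-science",  # illustrative only; derived from the project name
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    schedule_interval="@once",
    catchup=False,
    # Default settings applied to all tasks
    default_args=dict(
        owner="airflow-ds",  # pipeline-specific override from `data_science`
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
) as dag:
    ...  # one KedroOperator task is generated per pipeline node
```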

#### What if I want to pass different arguments?

To pass arguments other than those specified in the default template, provide a custom template (see _"What if I want to use a different Jinja2 template?"_).

The syntax for arguments is:
```
{{ argument_name }}
```

To make an argument optional, use:
```
{{ argument_name | default("default_value") }}
```

For examples, please have a look at the default template (`airflow_dag_template.j2`).

#### How can I use Airflow runtime parameters?

It is possible to pass parameters when triggering an Airflow DAG from the user interface.
To use this feature, create a custom template using the [Params syntax](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/params.html).
See _"What if I want to use a different Jinja2 template?"_ for instructions on using custom templates.
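
As a minimal sketch (not part of the default template), a custom template could expose a runtime parameter, here a hypothetical `run_date`, along these lines:

```python
from datetime import datetime

from airflow import DAG
from airflow.models.param import Param
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="example_runtime_params",  # hypothetical DAG for illustration
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    catchup=False,
    # `run_date` can be overridden when triggering the DAG with config in the UI
    params={"run_date": Param("2023-01-01", type="string")},
) as dag:
    BashOperator(
        task_id="print_run_date",
        bash_command="echo {{ params.run_date }}",
    )
```

Note that when such Airflow-templated expressions (e.g. `{{ params.run_date }}`) are added to a copy of `airflow_dag_template.j2`, they may need to be wrapped in `{% raw %}...{% endraw %}` so that kedro-airflow's own Jinja2 rendering does not consume them.
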
5 changes: 5 additions & 0 deletions kedro-airflow/RELEASE.md
@@ -1,6 +1,11 @@
# Upcoming release 0.5.2
* Replace references to the `kedro.pipeline.Pipeline` object throughout the test suite with the `kedro.modular_pipeline.pipeline` factory.
* Migrate all project metadata to static `pyproject.toml`.
* Configure DAG kwargs via `airflow.yml`.
* The generated DAG file now contains the pipeline name.
* Included help for CLI arguments (see `kedro airflow create --help`).
* Raise an error when the pipeline does not exist.
* Added tests for CLI arguments.

# Release 0.5.1
* Added additional CLI argument `--jinja-file` to provide a path to a custom Jinja2 template.
58 changes: 27 additions & 31 deletions kedro-airflow/kedro_airflow/airflow_dag_template.j2
@@ -10,7 +10,6 @@ from kedro.framework.project import configure_project


class KedroOperator(BaseOperator):

@apply_defaults
def __init__(
self,
@@ -35,46 +34,43 @@ class KedroOperator(BaseOperator):
env=self.env) as session:
session.run(self.pipeline_name, node_names=[self.node_name])


# Kedro settings required to run your pipeline
env = "{{ env }}"
pipeline_name = "{{ pipeline_name }}"
project_path = Path.cwd()
package_name = "{{ package_name }}"

# Default settings applied to all tasks
default_args = {
'owner': 'airflow',
'depends_on_past': False,
'email_on_failure': False,
'email_on_retry': False,
'retries': 1,
'retry_delay': timedelta(minutes=5)
}

# Using a DAG context manager, you don't have to specify the dag property of each task
with DAG(
"{{ dag_name | safe | slugify }}",
start_date=datetime(2019, 1, 1),
max_active_runs=3,
schedule_interval=timedelta(minutes=30), # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
default_args=default_args,
catchup=False # enable if you don't want historical dag runs to run
) as dag:

tasks = {}
{% for node in pipeline.nodes %}
tasks["{{ node.name | safe | slugify }}"] = KedroOperator(
task_id="{{ node.name | safe | slugify }}",
package_name=package_name,
pipeline_name=pipeline_name,
node_name="{{ node.name | safe }}",
project_path=project_path,
env=env,
dag_id="{{ dag_name | safe | slugify }}",
start_date=datetime({{ start_date | default([2023, 1, 1]) | join(",")}}),
max_active_runs={{ max_active_runs | default(3) }},
# https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
schedule_interval="{{ schedule_interval | default('@once') }}",
catchup={{ catchup | default(False) }},
# Default settings applied to all tasks
default_args=dict(
owner="{{ owner | default('airflow') }}",
depends_on_past={{ depends_on_past | default(False) }},
email_on_failure={{ email_on_failure | default(False) }},
email_on_retry={{ email_on_retry | default(False) }},
retries={{ retries | default(1) }},
retry_delay=timedelta(minutes={{ retry_delay | default(5) }})
)
{% endfor %}
) as dag:
tasks = {
{% for node in pipeline.nodes %} "{{ node.name | safe | slugify }}": KedroOperator(
task_id="{{ node.name | safe | slugify }}",
package_name=package_name,
pipeline_name=pipeline_name,
node_name="{{ node.name | safe }}",
project_path=project_path,
env=env,
),
{% endfor %} }

{% for parent_node, child_nodes in dependencies.items() -%}
{% for child in child_nodes %}
tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
{% for child in child_nodes %} tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
{% endfor %}
{%- endfor %}
71 changes: 65 additions & 6 deletions kedro-airflow/kedro_airflow/plugin.py
@@ -2,14 +2,23 @@

from collections import defaultdict
from pathlib import Path
from typing import Any

import click
import jinja2
from click import secho
from kedro.config import MissingConfigException
from kedro.framework.cli.project import PARAMS_ARG_HELP
from kedro.framework.cli.utils import ENV_HELP, KedroCliError, _split_params
from kedro.framework.context import KedroContext
from kedro.framework.project import pipelines
from kedro.framework.startup import ProjectMetadata
from kedro.framework.session import KedroSession
from kedro.framework.startup import ProjectMetadata, bootstrap_project
from slugify import slugify

PIPELINE_ARG_HELP = """Name of the registered pipeline to convert.
If not set, the '__default__' pipeline is used."""


@click.group(name="Kedro-Airflow")
def commands(): # pylint: disable=missing-function-docstring
@@ -22,15 +31,35 @@ def airflow_commands():
pass


def load_config(
context: KedroContext, pipeline_name: str, patterns: list[str]
) -> dict[str, Any]:
try:
config_airflow = context.config_loader.get(*patterns)
dag_config = {}
# Load the default config if specified
if "default" in config_airflow:
dag_config.update(config_airflow["default"])
# Update with pipeline-specific config if present
if pipeline_name in config_airflow:
dag_config.update(config_airflow[pipeline_name])
except MissingConfigException:
dag_config = {}
return dag_config


@airflow_commands.command()
@click.option("-p", "--pipeline", "pipeline_name", default="__default__")
@click.option("-e", "--env", default="local")
@click.option(
"-p", "--pipeline", "pipeline_name", default="__default__", help=PIPELINE_ARG_HELP
)
@click.option("-e", "--env", default="local", help=ENV_HELP)
@click.option(
"-t",
"--target-dir",
"target_path",
type=click.Path(writable=True, resolve_path=True, file_okay=False),
default="./airflow_dags/",
help="The directory path to store the generated Airflow dags",
)
@click.option(
"-j",
@@ -39,6 +68,22 @@ def airflow_commands():
exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
),
default=Path(__file__).parent / "airflow_dag_template.j2",
help="The template file for the generated Airflow dags",
)
@click.option(
"--params",
type=click.UNPROCESSED,
default="",
help=PARAMS_ARG_HELP,
callback=_split_params,
)
@click.option(
"-c",
"--config-patterns",
multiple=True,
type=click.STRING,
default=["airflow*", "airflow/**"],
help="Config pattern for airflow.yml",
)
@click.pass_obj
def create(
@@ -47,23 +92,36 @@ def create(
env,
target_path,
jinja_file,
params,
config_patterns,
): # pylint: disable=too-many-locals,too-many-arguments
"""Create an Airflow DAG for a project"""
project_path = Path().cwd()
bootstrap_project(project_path)
with KedroSession.create(project_path=project_path, env=env) as session:
context = session.load_context()
dag_config = load_config(context, pipeline_name, config_patterns)

# Update with params if provided
dag_config.update(params)

jinja_file = Path(jinja_file).resolve()
loader = jinja2.FileSystemLoader(jinja_file.parent)
jinja_env = jinja2.Environment(autoescape=True, loader=loader, lstrip_blocks=True)
jinja_env.filters["slugify"] = slugify
template = jinja_env.get_template(jinja_file.name)

package_name = metadata.package_name
dag_filename = f"{package_name}_dag.py"
dag_filename = f"{package_name}_{pipeline_name}_dag.py"

target_path = Path(target_path)
target_path = target_path / dag_filename

target_path.parent.mkdir(parents=True, exist_ok=True)

pipeline = pipelines.get(pipeline_name)
if pipeline is None:
raise KedroCliError(f"Pipeline {pipeline_name} not found.")

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
@@ -77,14 +135,16 @@ def create(
pipeline_name=pipeline_name,
package_name=package_name,
pipeline=pipeline,
**dag_config,
).dump(str(target_path))

secho("")
secho("An Airflow DAG has been generated in:", fg="green")
secho(str(target_path))
secho("This file should be copied to your Airflow DAG folder.", fg="yellow")
secho(
"The Airflow configuration can be customized by editing this file.", fg="green"
"The Airflow configuration can be customized by editing this file.",
fg="green",
)
secho("")
secho(
@@ -101,4 +161,3 @@
"And all local paths in both the data catalog and log config must be absolute paths.",
fg="yellow",
)
secho("")
1 change: 1 addition & 0 deletions kedro-airflow/test_requirements.txt
@@ -11,3 +11,4 @@ pytest-mock
pytest-xdist
trufflehog>=2.1.0, <3.0
wheel
pyyaml
61 changes: 57 additions & 4 deletions kedro-airflow/tests/conftest.py
@@ -4,16 +4,19 @@
discover them automatically. More info here:
https://docs.pytest.org/en/latest/fixture.html
"""
import os
from pathlib import Path
from shutil import copyfile

from click.testing import CliRunner
from cookiecutter.main import cookiecutter
from kedro import __version__ as kedro_version
from kedro.framework.cli.starters import TEMPLATE_PATH
from kedro.framework.startup import ProjectMetadata
from pytest import fixture


@fixture(name="cli_runner")
@fixture(name="cli_runner", scope="session")
def cli_runner():
runner = CliRunner()
cwd = Path.cwd()
@@ -23,10 +26,60 @@ def cli_runner():
yield runner


@fixture
def metadata(cli_runner): # pylint: disable=unused-argument
@fixture(scope="session")
def kedro_project(cli_runner):
tmp_path = Path().cwd()
# From `kedro-mlflow.tests.conftest.py`
config = {
"output_dir": tmp_path,
"kedro_version": kedro_version,
"project_name": "This is a fake project",
"repo_name": "fake-project",
"python_package": "fake_project",
"include_example": True,
}

cookiecutter(
str(TEMPLATE_PATH),
output_dir=config["output_dir"],
no_input=True,
extra_context=config,
)

pipeline_registry_py = """
from kedro.pipeline import Pipeline, node
def identity(arg):
return arg
def register_pipelines():
pipeline = Pipeline(
[
node(identity, ["input"], ["intermediate"], name="node0"),
node(identity, ["intermediate"], ["output"], name="node1"),
],
tags="pipeline0",
)
return {
"__default__": pipeline,
"ds": pipeline,
}
"""

(
tmp_path / "fake-project" / "src" / "fake_project" / "pipeline_registry.py"
).write_text(pipeline_registry_py)

os.chdir(tmp_path / "fake-project")
return tmp_path / "fake-project"


@fixture(scope="session")
def metadata(kedro_project): # pylint: disable=unused-argument
# cwd() depends on ^ the isolated filesystem, created by CliRunner()
project_path = Path.cwd()
project_path = kedro_project
return ProjectMetadata(
project_path / "pyproject.toml",
"hello_world",