Skip to content

Commit

Permalink
feat: config template
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Brugman <[email protected]>
  • Loading branch information
sbrugman committed Jun 13, 2023
1 parent e026c06 commit 8e5bc37
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 15 deletions.
14 changes: 14 additions & 0 deletions kedro-airflow/kedro_airflow/airflow_dag_kwargs_template.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{#-
  Default keyword arguments for the generated Airflow DAG.

  airflow_dag_template.j2 includes this file (with context) inside its
  DAG(...) call, directly after dag_id, and passes the result through an
  indent(width=4) filter. Everything below is therefore emitted as Python
  keyword arguments, not as standalone statements.

  Context used: default_args['owner'].
  NOTE(review): presumably supplied by the project's airflow config or the
  --params option, both of which are merged into the render context by
  plugin.py; rendering looks like it fails when no default_args mapping is
  present -- confirm.
  NOTE(review): the emitted code references datetime and timedelta; assumes
  the including template imports them -- not visible from here, confirm.

  This comment trims its surrounding whitespace, so it adds nothing to the
  rendered output.
-#}
start_date=datetime(2023, 1, 1),
max_active_runs=3,
# https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
schedule_interval="@once",
catchup=False,
# Default settings applied to all tasks
default_args=dict(
owner="{{ default_args['owner'] }}",
depends_on_past=False,
email_on_failure=False,
email_on_retry=False,
retries=1,
retry_delay=timedelta(minutes=5)
)
11 changes: 6 additions & 5 deletions kedro-airflow/kedro_airflow/airflow_dag_template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ package_name = "{{ package_name }}"
# Using a DAG context manager, you don't have to specify the dag property of each task
with DAG(
dag_id="{{ dag_name | safe | slugify }}",
{% if kwargs %}{% filter indent(width=4) %}
{{ kwargs | safe }}
{% endfilter %}{% endif -%}
{% if kwargs_template -%}
{% filter indent(width=4) %} {% include kwargs_template with context -%}
{% endfilter %}
{%- endif %}
) as dag:
tasks = {
{% for node in pipeline.nodes %} "{{ node.name | safe | slugify }}": KedroOperator(
Expand All @@ -57,8 +58,8 @@ with DAG(
project_path=project_path,
env=env,
),
{% endfor %}
}
{% endfor %} }

{% for parent_node, child_nodes in dependencies.items() -%}
{% for child in child_nodes %} tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
{% endfor %}
Expand Down
50 changes: 41 additions & 9 deletions kedro-airflow/kedro_airflow/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import jinja2
from click import secho
from kedro.config import MissingConfigException
from kedro.framework.cli.utils import ENV_HELP
from kedro.framework.cli.project import PARAMS_ARG_HELP
from kedro.framework.cli.utils import ENV_HELP, KedroCliError, _split_params
from kedro.framework.project import pipelines
from kedro.framework.session import KedroSession
from kedro.framework.startup import ProjectMetadata, bootstrap_project
Expand Down Expand Up @@ -40,7 +41,7 @@ def airflow_commands():
"target_path",
type=click.Path(writable=True, resolve_path=True, file_okay=False),
default="./airflow_dags/",
help="The directory path to store the generated Airflow dags"
help="The directory path to store the generated Airflow dags",
)
@click.option(
"-j",
Expand All @@ -49,7 +50,23 @@ def airflow_commands():
exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
),
default=Path(__file__).parent / "airflow_dag_template.j2",
help="The template file for the generated Airflow dags"
help="The template file for the generated Airflow dags",
)
@click.option(
"-k",
"--kwargs-template",
type=click.Path(
exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
),
default=Path(__file__).parent / "airflow_dag_kwargs_template.j2",
help="The template file for the kwargs in the Airflow dags",
)
@click.option(
"--params",
type=click.UNPROCESSED,
default="",
help=PARAMS_ARG_HELP,
callback=_split_params,
)
@click.pass_obj
def create(
Expand All @@ -58,6 +75,8 @@ def create(
env,
target_path,
jinja_file,
kwargs_template,
params,
): # pylint: disable=too-many-locals,too-many-arguments
"""Create an Airflow DAG for a project"""
project_path = Path().cwd()
Expand All @@ -66,20 +85,30 @@ def create(
context = session.load_context()
try:
config_airflow = context.config_loader.get("airflow*", "airflow/**")
dag_config = {}
# Load the default config if specified
if "default" in config_airflow:
dag_config.update(config_airflow["default"])
# Update with pipeline-specific config if present
if pipeline_name in config_airflow:
kwargs = config_airflow[pipeline_name]
else:
kwargs = config_airflow["kwargs"]
dag_config.update(config_airflow[pipeline_name])
except MissingConfigException:
kwargs = {}
dag_config = {}

# Update with params if provided
dag_config.update(params)

jinja_file = Path(jinja_file).resolve()
kwargs_template = Path(kwargs_template).resolve()
if jinja_file.parent != kwargs_template.parent:
raise KedroCliError(f"Templates should be placed in the same directory.")
loader = jinja2.FileSystemLoader(jinja_file.parent)
jinja_env = jinja2.Environment(
autoescape=True, loader=loader, lstrip_blocks=True, trim_blocks=True
autoescape=True, loader=loader, lstrip_blocks=True # , trim_blocks=True
)
jinja_env.filters["slugify"] = slugify
template = jinja_env.get_template(jinja_file.name)
kwargs_template = kwargs_template.name

package_name = metadata.package_name
dag_filename = f"{package_name}_{pipeline_name}_dag.py"
Expand All @@ -90,6 +119,8 @@ def create(
target_path.parent.mkdir(parents=True, exist_ok=True)

pipeline = pipelines.get(pipeline_name)
if pipeline is None:
raise KedroCliError(f"Pipeline {pipeline_name} not found.")

dependencies = defaultdict(list)
for node, parent_nodes in pipeline.node_dependencies.items():
Expand All @@ -103,7 +134,8 @@ def create(
pipeline_name=pipeline_name,
package_name=package_name,
pipeline=pipeline,
kwargs=kwargs,
kwargs_template=kwargs_template,
**dag_config,
).dump(str(target_path))

secho("")
Expand Down
5 changes: 4 additions & 1 deletion kedro-airflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ packages = ["kedro_airflow"]
zip-safe = false

[tool.setuptools.package-data]
kedro_airflow = ["kedro_airflow/airflow_dag_template.j2"]
kedro_airflow = [
"kedro_airflow/airflow_dag_template.j2",
"kedro_airflow/airflow_dag_kwargs_template.j2"
]

[tool.setuptools.dynamic]
readme = {file = "README.md", content-type = "text/markdown"}
Expand Down

0 comments on commit 8e5bc37

Please sign in to comment.