diff --git a/kedro-airflow/kedro_airflow/airflow_dag_kwargs_template.j2 b/kedro-airflow/kedro_airflow/airflow_dag_kwargs_template.j2
new file mode 100644
index 000000000..5cb88ef81
--- /dev/null
+++ b/kedro-airflow/kedro_airflow/airflow_dag_kwargs_template.j2
@@ -0,0 +1,14 @@
+start_date=datetime(2023, 1, 1),
+max_active_runs=3,
+# https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
+schedule_interval="@once",
+catchup=False,
+# Default settings applied to all tasks
+default_args=dict(
+    owner="{{ default_args['owner'] }}",
+    depends_on_past=False,
+    email_on_failure=False,
+    email_on_retry=False,
+    retries=1,
+    retry_delay=timedelta(minutes=5)
+)
diff --git a/kedro-airflow/kedro_airflow/airflow_dag_template.j2 b/kedro-airflow/kedro_airflow/airflow_dag_template.j2
index 0538c5bdd..e37cb3e2c 100644
--- a/kedro-airflow/kedro_airflow/airflow_dag_template.j2
+++ b/kedro-airflow/kedro_airflow/airflow_dag_template.j2
@@ -44,9 +44,10 @@ package_name = "{{ package_name }}"
 # Using a DAG context manager, you don't have to specify the dag property of each task
 with DAG(
     dag_id="{{ dag_name | safe | slugify }}",
-    {% if kwargs %}{% filter indent(width=4) %}
-    {{ kwargs | safe }}
-    {% endfilter %}{% endif -%}
+    {% if kwargs_template -%}
+    {% filter indent(width=4) %} {% include kwargs_template with context -%}
+    {% endfilter %}
+    {%- endif %}
 ) as dag:
     tasks = {
     {% for node in pipeline.nodes %}    "{{ node.name | safe | slugify }}": KedroOperator(
@@ -57,8 +58,8 @@ with DAG(
             project_path=project_path,
             env=env,
         ),
-{% endfor %}
-    }
+{% endfor %}    }
+
     {% for parent_node, child_nodes in dependencies.items() -%}
     {% for child in child_nodes %}    tasks["{{ parent_node.name | safe | slugify }}"] >> tasks["{{ child.name | safe | slugify }}"]
     {% endfor %}
diff --git a/kedro-airflow/kedro_airflow/plugin.py b/kedro-airflow/kedro_airflow/plugin.py
index d23e6bcd6..6f717e8bf 100644
--- a/kedro-airflow/kedro_airflow/plugin.py
+++ b/kedro-airflow/kedro_airflow/plugin.py
@@ -7,7 +7,8 @@
 import jinja2
 from click import secho
 from kedro.config import MissingConfigException
-from kedro.framework.cli.utils import ENV_HELP
+from kedro.framework.cli.project import PARAMS_ARG_HELP
+from kedro.framework.cli.utils import ENV_HELP, KedroCliError, _split_params
 from kedro.framework.project import pipelines
 from kedro.framework.session import KedroSession
 from kedro.framework.startup import ProjectMetadata, bootstrap_project
@@ -40,7 +41,7 @@ def airflow_commands():
     "target_path",
     type=click.Path(writable=True, resolve_path=True, file_okay=False),
     default="./airflow_dags/",
-    help="The directory path to store the generated Airflow dags"
+    help="The directory path to store the generated Airflow dags",
 )
 @click.option(
     "-j",
@@ -49,7 +50,23 @@ def airflow_commands():
         exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
     ),
     default=Path(__file__).parent / "airflow_dag_template.j2",
-    help="The template file for the generated Airflow dags"
+    help="The template file for the generated Airflow dags",
+)
+@click.option(
+    "-k",
+    "--kwargs-template",
+    type=click.Path(
+        exists=True, readable=True, resolve_path=True, file_okay=True, dir_okay=False
+    ),
+    default=Path(__file__).parent / "airflow_dag_kwargs_template.j2",
+    help="The template file for the kwargs in the Airflow dags",
+)
+@click.option(
+    "--params",
+    type=click.UNPROCESSED,
+    default="",
+    help=PARAMS_ARG_HELP,
+    callback=_split_params,
 )
 @click.pass_obj
 def create(
@@ -58,6 +75,8 @@ def create(
     env,
     target_path,
     jinja_file,
+    kwargs_template,
+    params,
 ):  # pylint: disable=too-many-locals,too-many-arguments
     """Create an Airflow DAG for a project"""
     project_path = Path().cwd()
@@ -66,20 +85,30 @@ def create(
         context = session.load_context()
         try:
             config_airflow = context.config_loader.get("airflow*", "airflow/**")
+            dag_config = {}
+            # Load the default config if specified
+            if "default" in config_airflow:
+                dag_config.update(config_airflow["default"])
+            # Update with pipeline-specific config if present
             if pipeline_name in config_airflow:
-                kwargs = config_airflow[pipeline_name]
-            else:
-                kwargs = config_airflow["kwargs"]
+                dag_config.update(config_airflow[pipeline_name])
         except MissingConfigException:
-            kwargs = {}
+            dag_config = {}
+
+        # Update with params if provided
+        dag_config.update(params)
 
         jinja_file = Path(jinja_file).resolve()
+        kwargs_template = Path(kwargs_template).resolve()
+        if jinja_file.parent != kwargs_template.parent:
+            raise KedroCliError(f"Templates should be placed in the same directory.")
         loader = jinja2.FileSystemLoader(jinja_file.parent)
         jinja_env = jinja2.Environment(
-            autoescape=True, loader=loader, lstrip_blocks=True, trim_blocks=True
+            autoescape=True, loader=loader, lstrip_blocks=True  # , trim_blocks=True
         )
         jinja_env.filters["slugify"] = slugify
         template = jinja_env.get_template(jinja_file.name)
+        kwargs_template = kwargs_template.name
 
         package_name = metadata.package_name
         dag_filename = f"{package_name}_{pipeline_name}_dag.py"
@@ -90,6 +119,8 @@ def create(
         target_path.parent.mkdir(parents=True, exist_ok=True)
 
         pipeline = pipelines.get(pipeline_name)
+        if pipeline is None:
+            raise KedroCliError(f"Pipeline {pipeline_name} not found.")
 
         dependencies = defaultdict(list)
         for node, parent_nodes in pipeline.node_dependencies.items():
@@ -103,7 +134,8 @@ def create(
             pipeline_name=pipeline_name,
             package_name=package_name,
             pipeline=pipeline,
-            kwargs=kwargs,
+            kwargs_template=kwargs_template,
+            **dag_config,
         ).dump(str(target_path))
 
         secho("")
diff --git a/kedro-airflow/pyproject.toml b/kedro-airflow/pyproject.toml
index 42fe8974b..6787a54e8 100644
--- a/kedro-airflow/pyproject.toml
+++ b/kedro-airflow/pyproject.toml
@@ -31,7 +31,10 @@ packages = ["kedro_airflow"]
 zip-safe = false
 
 [tool.setuptools.package-data]
-kedro_airflow = ["kedro_airflow/airflow_dag_template.j2"]
+kedro_airflow = [
+    "kedro_airflow/airflow_dag_template.j2",
+    "kedro_airflow/airflow_dag_kwargs_template.j2"
+]
 
 [tool.setuptools.dynamic]
 readme = {file = "README.md", content-type = "text/markdown"}
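
For reference, a minimal sketch of the DAG header these two templates are expected to render once airflow_dag_kwargs_template.j2 is included into airflow_dag_template.j2. It is illustrative only and not part of this change: the dag_id is a hypothetical slugified name, and the owner value assumes the merged dag_config (an airflow* config file's "default" or pipeline entry, or --params on the CLI) supplies default_args={"owner": "airflow"}.

# Illustrative sketch: approximate output of the templates for an assumed dag_config.
from datetime import datetime, timedelta

from airflow import DAG

with DAG(
    dag_id="my-project-default",  # hypothetical result of the slugify filter on dag_name
    start_date=datetime(2023, 1, 1),
    max_active_runs=3,
    # https://airflow.apache.org/docs/stable/scheduler.html#dag-runs
    schedule_interval="@once",
    catchup=False,
    # Default settings applied to all tasks
    default_args=dict(
        owner="airflow",  # rendered from {{ default_args['owner'] }} in the kwargs template
        depends_on_past=False,
        email_on_failure=False,
        email_on_retry=False,
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
) as dag:
    pass  # the KedroOperator tasks and their ordering are generated after this header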