[FEATURE] 163 - For each task (#185)
* Support for foreach task

* tests

* refactor: DatabricksBundleCodegen to improve task handling and streamline workflow task generation

* fix: update task builder function to use nested task type for improved accuracy

* Builders in _get_task_builder can also be resolved from the task class

* Fix tests and serialization issue with ForEachTask custom class

* Support for brickflow task type in for each task

* Setting push_return_value in task response to False if task type is for each task

* Refactor the for-each task build function so that task type validation also covers nested tasks of Brickflow type

* Support for spark jar task type in for each task

* Tests for run job task type in for each task

* Tests for sql task type in for each task

* feat: add JobsTasksForEachTaskConfigs for improved task configuration and validation

* feat: add JobsTasksForEachTaskConfigs for for-each task configuration and validation

* feat: implement model validation for ForEachTask configuration inputs and concurrency

* feat: simplify ForEachTask initialization by using task configuration object

* fix format

* fix format

* Removed TODO, nested task name is not exposed

* Documentation of for each task type

* For each task examples

* Moved for each task config validation up

* Updates to doc and examples after introduction of for each task config model

* chore: update Python version to 3.9 and refine configuration files

* chore: update documentation dependencies and versions in pyproject.toml

* Fixed formatting of bullet points in doc

* Update brickflow project conf in for each examples

---------

Co-authored-by: Mikita Sakalouski <[email protected]>
riccamini and mikita-sakalouski authored Dec 18, 2024
1 parent 78154e4 commit c652522
Showing 24 changed files with 2,877 additions and 1,814 deletions.
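
As an illustration of what the new task type enables, a workflow could declare a for-each task roughly as in the sketch below. This is a hedged example, not code from this commit: the for_each_task decorator is only referred to in a code comment further down, so its name and signature are assumed here, as is the "{{input}}" iteration reference.

from brickflow import Workflow
from brickflow.engine.task import JobsTasksForEachTaskConfigs

wf = Workflow("for_each_demo")  # hypothetical workflow

@wf.for_each_task(  # decorator name assumed from the commit comments
    for_each_task_conf=JobsTasksForEachTaskConfigs(
        inputs=["2024-01-01", "2024-01-02"],  # non-string inputs are JSON-encoded by the config model
        concurrency=2,  # run up to two iterations in parallel
    )
)
def process_date(*, date: str = "{{input}}") -> None:  # "{{input}}" reference assumed
    print(f"processing {date}")
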
2 changes: 1 addition & 1 deletion .github/workflows/onpush.yml
@@ -17,7 +17,7 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: [ '3.8' ]
python-version: [ '3.9' ]
os: [ ubuntu-latest ]

steps:
7 changes: 5 additions & 2 deletions brickflow/bundles/model.py
@@ -6,7 +6,7 @@

from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Extra, Field, constr
from pydantic import BaseModel, Field, constr, InstanceOf
from typing_extensions import Literal


@@ -1174,7 +1174,10 @@ class Config:


class JobsTasksForEachTask(BaseModel):
pass
inputs: str
concurrency: int
task: JobsTasks



class JobsTasksHealthRules(BaseModel):
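
For reference, the regenerated bundle model now nests a full JobsTasks object under the loop. A minimal construction might look as follows (a sketch; the notebook path and task key are made up, and JobsTasks/JobsTasksNotebookTask are assumed to require only the fields shown):

from brickflow.bundles.model import (
    JobsTasks,
    JobsTasksForEachTask,
    JobsTasksNotebookTask,
)

inner = JobsTasks(
    task_key="process_date_nested",  # hypothetical; codegen suffixes the parent task name with "_nested"
    notebook_task=JobsTasksNotebookTask(notebook_path="/Workspace/notebooks/process"),
)
loop = JobsTasksForEachTask(
    inputs='["2024-01-01", "2024-01-02"]',  # JSON array, or a reference to an array job parameter
    concurrency=2,
    task=inner,
)
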
191 changes: 148 additions & 43 deletions brickflow/codegen/databricks_bundle.py
@@ -60,6 +60,13 @@
)
from brickflow.engine.task import (
DLTPipeline,
ForEachTask,
IfElseConditionTask,
NotebookTask,
RunJobTask,
SparkJarTask,
SparkPythonTask,
SqlTask,
TaskLibrary,
TaskSettings,
filter_bf_related_libraries,
@@ -498,12 +505,11 @@ def adjust_file_path(self, file_path: str) -> str:
return file_path

def task_to_task_obj(self, task: Task) -> JobsTasksNotebookTask:
if task.task_type in [TaskType.BRICKFLOW_TASK, TaskType.CUSTOM_PYTHON_TASK]:
generated_path = handle_mono_repo_path(self.project, self.env)
return JobsTasksNotebookTask(
**task.get_obj_dict(generated_path),
source=self.adjust_source(),
)
generated_path = handle_mono_repo_path(self.project, self.env)
return JobsTasksNotebookTask(
**task.get_obj_dict(generated_path),
source=self.adjust_source(),
)

def workflow_obj_to_pipelines(self, workflow: Workflow) -> Dict[str, Pipelines]:
pipelines_dict = {}
@@ -760,6 +766,70 @@ def _build_dlt_task(
task_key=task_name,
)

def _build_native_for_each_task(
self,
task_name: str,
task: Task,
task_libraries: List[JobsTasksLibraries],
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
**kwargs: Any,
) -> JobsTasks:
supported_task_types = (
TaskType.NOTEBOOK_TASK,
TaskType.SPARK_JAR_TASK,
TaskType.SPARK_PYTHON_TASK,
TaskType.RUN_JOB_TASK,
TaskType.SQL,
TaskType.BRICKFLOW_TASK, # Accounts for brickflow entrypoint tasks
)

if task.for_each_task_conf is None:
raise ValueError(
f"Error while building for each task {task_name}. "
f"Make sure {task_name} has a for_each_task_conf attribute."
)

nested_task = task.task_func()
task_type = self._get_task_type(nested_task)

try:
assert task_type in supported_task_types
except AssertionError as e:
raise ValueError(
f"Error while building python task {task_name}. Make sure {task_name} is one of "
f"{', '.join(task_type.__name__ for task_type in supported_task_types)}."
) from e

builder_func = self._get_task_builder(task_type=task_type)

workflow: Optional[Workflow] = kwargs.get("workflow")
# Currently the inner task name is not exposed; a parameter will have to be added to the for_each_task
# decorator to allow the user to configure it
nested_task_jt = builder_func(
task_name=f"{task_name}_nested",
task=task,
workflow=workflow,
task_libraries=task_libraries,
task_settings=task_settings,
depends_on=[],
)

for_each_task = ForEachTask(
configs=task.for_each_task_conf,
task=nested_task_jt,
)

# We are not specifying any cluster or libraries as for_each_task cannot have them!
jt = JobsTasks(
**task_settings.to_tf_dict(),
for_each_task=for_each_task,
depends_on=depends_on,
task_key=task_name,
)

return jt

def _build_brickflow_entrypoint_task(
self,
task_name: str,
@@ -771,7 +841,7 @@ def _build_brickflow_entrypoint_task(
) -> JobsTasks:
task_obj = JobsTasks(
**{
task.databricks_task_type_str: self.task_to_task_obj(task),
TaskType.NOTEBOOK_TASK.value: self.task_to_task_obj(task),
**task_settings.to_tf_dict(),
}, # type: ignore
depends_on=depends_on,
Expand All @@ -791,56 +861,91 @@ def _build_brickflow_entrypoint_task(
)
return task_obj

def workflow_obj_to_tasks(
self, workflow: Workflow
) -> List[Union[JobsTasks, Pipelines]]:
tasks = []
def _get_task_type(self, task: Any) -> TaskType:
"""Resolves the task type given the task object"""

map_task_class_to_task_type: Dict[typing.Type, TaskType] = {
DLTPipeline: TaskType.DLT,
NotebookTask: TaskType.NOTEBOOK_TASK,
SparkJarTask: TaskType.SPARK_JAR_TASK,
SparkPythonTask: TaskType.SPARK_PYTHON_TASK,
RunJobTask: TaskType.RUN_JOB_TASK,
SqlTask: TaskType.SQL,
IfElseConditionTask: TaskType.IF_ELSE_CONDITION_TASK,
ForEachTask: TaskType.FOR_EACH_TASK,
}

# Brickflow tasks do not have a dedicated task class, so everything else is matched to this type
return map_task_class_to_task_type.get(type(task), TaskType.BRICKFLOW_TASK)

def _get_task_builder(self, task_type: TaskType = None) -> Callable[..., Any]:
map_task_type_to_builder: Dict[TaskType, Callable[..., Any]] = {
TaskType.BRICKFLOW_TASK: self._build_brickflow_entrypoint_task,
TaskType.DLT: self._build_dlt_task,
TaskType.NOTEBOOK_TASK: self._build_native_notebook_task,
TaskType.SPARK_JAR_TASK: self._build_native_spark_jar_task,
TaskType.SPARK_PYTHON_TASK: self._build_native_spark_python_task,
TaskType.RUN_JOB_TASK: self._build_native_run_job_task,
TaskType.SQL: self._build_native_sql_file_task,
TaskType.IF_ELSE_CONDITION_TASK: self._build_native_condition_task,
TaskType.FOR_EACH_TASK: self._build_native_for_each_task,
TaskType.CUSTOM_PYTHON_TASK: self._build_brickflow_entrypoint_task,
}

for task_name, task in workflow.tasks.items():
builder_func = map_task_type_to_builder.get(
task.task_type, self._build_brickflow_entrypoint_task
)
builder = map_task_type_to_builder.get(task_type, None)
if builder is None:
raise ValueError("No builder found for the given task or task class")
return builder

def _build_task(
self, build_func: Callable, workflow: Workflow, task_name: str, task: Task
) -> Union[JobsTasks, Pipelines]:
# TODO: DLT
# pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
if task.depends_on_names:
depends_on = [
JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome)
for i in task.depends_on_names
for depends_key, expected_outcome in i.items()
] # type: ignore
else:
depends_on = []

# TODO: DLT
# pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
if task.depends_on_names:
depends_on = [
JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome)
for i in task.depends_on_names
for depends_key, expected_outcome in i.items()
] # type: ignore
else:
depends_on = []
libraries = TaskLibrary.unique_libraries(
task.libraries + (self.project.libraries or [])
)
if workflow.enable_plugins is True:
libraries = filter_bf_related_libraries(libraries)
libraries += get_brickflow_libraries(workflow.enable_plugins)
libraries = TaskLibrary.unique_libraries(
task.libraries + (self.project.libraries or [])
)
if workflow.enable_plugins is True:
libraries = filter_bf_related_libraries(libraries)
libraries += get_brickflow_libraries(workflow.enable_plugins)

task_libraries = [JobsTasksLibraries(**library.dict) for library in libraries] # type: ignore
task_settings = workflow.default_task_settings.merge(task.task_settings) # type: ignore
task = build_func(
task_name=task_name,
task=task,
workflow=workflow,
task_libraries=task_libraries,
task_settings=task_settings,
depends_on=depends_on,
)

task_libraries = [
JobsTasksLibraries(**library.dict) for library in libraries
] # type: ignore
task_settings = workflow.default_task_settings.merge(task.task_settings) # type: ignore
task = builder_func(
task_name=task_name,
task=task,
workflow=workflow,
task_libraries=task_libraries,
task_settings=task_settings,
depends_on=depends_on,
)
return task

def workflow_obj_to_tasks(
self, workflow: Workflow
) -> List[Union[JobsTasks, Pipelines]]:
tasks = []

for task_name, task in workflow.tasks.items():
build_func = self._get_task_builder(task_type=task.task_type)
tasks.append(
self._build_task(
build_func=build_func,
workflow=workflow,
task_name=task_name,
task=task,
)
)
tasks.append(task)

tasks.sort(key=lambda t: (t.task_key is None, t.task_key))

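
Conceptually, the refactor above splits task generation into type resolution, builder lookup, and building. A sketch of the flow (assuming a DatabricksBundleCodegen instance named codegen, plus workflow, task and task_settings objects already in scope):

nested = task.task_func()                       # e.g. returns a NotebookTask, SparkJarTask, SqlTask, ...
task_type = codegen._get_task_type(nested)      # class -> TaskType, falling back to BRICKFLOW_TASK
builder = codegen._get_task_builder(task_type)  # TaskType -> builder method; raises ValueError if unknown
job_task = builder(
    task_name="my_task",
    task=task,
    workflow=workflow,
    task_libraries=[],
    task_settings=task_settings,
    depends_on=[],
)
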
45 changes: 44 additions & 1 deletion brickflow/engine/task.py
@@ -28,6 +28,7 @@

import pluggy
from decouple import config
from pydantic import BaseModel, Field, field_validator, model_validator

from brickflow import (
BrickflowDefaultEnvs,
@@ -38,6 +39,7 @@
)
from brickflow.bundles.model import (
JobsTasksConditionTask,
JobsTasksForEachTask,
JobsTasksHealthRules,
JobsTasksNotebookTask,
JobsTasksNotificationSettings,
@@ -123,6 +125,7 @@ class TaskType(Enum):
SPARK_PYTHON_TASK = "spark_python_task"
RUN_JOB_TASK = "run_job_task"
IF_ELSE_CONDITION_TASK = "condition_task"
FOR_EACH_TASK = "for_each_task"


class TaskRunCondition(Enum):
@@ -493,6 +496,44 @@ def __init__(self, **kwargs: Any) -> None:
self.python_file = kwargs.get("python_file", None)


class JobsTasksForEachTaskConfigs(BaseModel):
inputs: str = Field(..., description="The input data for the task.")
concurrency: int = Field(
default=1, description="Number of iterations that can run in parallel."
)

@field_validator("inputs", mode="before")
@classmethod
def validate_inputs(cls, inputs: Any) -> str:
if not isinstance(inputs, str):
inputs = json.dumps(inputs)
return inputs


class ForEachTask(JobsTasksForEachTask):
"""
The ForEachTask class iterates a task over a list of inputs. The looped task can be executed
concurrently, based on the configured concurrency value.

Attributes:
inputs (str): Array for the task to iterate on. This can be a JSON string or a reference to an array parameter.
concurrency (int): Optional maximum number of concurrent runs of the task. Set this value if you want
to execute multiple iterations of the task concurrently.
task (Any): The task that will be run for each element in the array.
"""

configs: JobsTasksForEachTaskConfigs
task: Any

@model_validator(mode="before")
def validate_configs(self) -> "ForEachTask":
self["inputs"] = self["configs"].inputs # type: ignore
self["concurrency"] = self["configs"].concurrency # type: ignore

return self


class RunJobTask(JobsTasksRunJobTask):
"""
The RunJobTask class is designed to handle the execution of a specific job in a Databricks workspace.
@@ -704,10 +745,11 @@ def task_execute(task: "Task", workflow: "Workflow") -> TaskResponse:
else:
kwargs = task.get_runtime_parameter_values()
try:
# Task return value cannot be pushed if we are in a for each task (not allowed by Databricks)
return TaskResponse(
task.task_func(**kwargs),
user_code_error=None,
push_return_value=True,
push_return_value=not task.task_type == TaskType.FOR_EACH_TASK,
input_kwargs=kwargs,
)
except Exception as e:
@@ -807,6 +849,7 @@ class Task:
ensure_brickflow_plugins: bool = False
health: Optional[List[JobsTasksHealthRules]] = None
if_else_outcome: Optional[Dict[Union[str, str], str]] = None
for_each_task_conf: Optional[JobsTasksForEachTaskConfigs] = None

def __post_init__(self) -> None:
self.is_valid_task_signature()
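
As a small illustration of the new inputs validator (a sketch, assuming the config model is importable as shown): non-string inputs are serialized to a JSON string before being passed to the underlying bundle model.

import json

from brickflow.engine.task import JobsTasksForEachTaskConfigs

cfg = JobsTasksForEachTaskConfigs(inputs=["2024-01-01", "2024-01-02"], concurrency=2)
assert cfg.inputs == json.dumps(["2024-01-01", "2024-01-02"])  # '["2024-01-01", "2024-01-02"]'
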