diff --git a/aiida_workgraph/calculations/python_parser.py b/aiida_workgraph/calculations/python_parser.py
index 4284f6ab..17d71141 100644
--- a/aiida_workgraph/calculations/python_parser.py
+++ b/aiida_workgraph/calculations/python_parser.py
@@ -1,6 +1,7 @@
 """Parser for an `PythonJob` job."""
 from aiida.parsers.parser import Parser
 from aiida_workgraph.orm import general_serializer
+from aiida.engine import ExitCode
 
 
 class PythonParser(Parser):
@@ -31,13 +32,14 @@ def parse(self, **kwargs):
                 "remote_folder",
                 "remote_stash",
                 "retrieved",
+                "exit_code",
             ]
         ]
         # first we remove nested outputs, e.g., "add_multiply.add"
         top_level_output_list = [
             output for output in self.output_list if "." not in output["name"]
         ]
-
+        exit_code = 0
         try:
             with self.retrieved.base.repository.open("results.pickle", "rb") as handle:
                 results = pickle.load(handle)
@@ -49,6 +51,8 @@ def parse(self, **kwargs):
                         results[i], top_level_output_list[i]
                     )
             elif isinstance(results, dict) and len(top_level_output_list) > 1:
+                # pop the exit code if it exists
+                exit_code = results.pop("exit_code", 0)
                 for output in top_level_output_list:
                     if output.get("required", False):
                         if output["name"] not in results:
@@ -62,6 +66,7 @@ def parse(self, **kwargs):
                         f"Found extra results that are not included in the output: {results.keys()}"
                     )
             elif isinstance(results, dict) and len(top_level_output_list) == 1:
+                exit_code = results.pop("exit_code", 0)
                 # if output name in results, use it
                 if top_level_output_list[0]["name"] in results:
                     top_level_output_list[0]["value"] = self.serialize_output(
@@ -84,6 +89,12 @@ def parse(self, **kwargs):
                 )
             for output in top_level_output_list:
                 self.out(output["name"], output["value"])
+            if exit_code:
+                if isinstance(exit_code, dict):
+                    exit_code = ExitCode(exit_code["status"], exit_code["message"])
+                elif isinstance(exit_code, int):
+                    exit_code = ExitCode(exit_code)
+                return exit_code
         except OSError:
             return self.exit_codes.ERROR_READING_OUTPUT_FILE
         except ValueError as exception:
diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py
index 2f925230..0af0173b 100644
--- a/aiida_workgraph/decorator.py
+++ b/aiida_workgraph/decorator.py
@@ -280,6 +280,7 @@ def build_task_from_AiiDA(
 def build_pythonjob_task(func: Callable) -> Task:
     """Build PythonJob task from function."""
     from aiida_workgraph.calculations.python import PythonJob
+    from aiida_workgraph.tasks.pythonjob import PythonJob as PythonJobTask
     from copy import deepcopy
 
     # if the function is not a task, build a task from the function
@@ -310,6 +311,7 @@ def build_pythonjob_task(func: Callable) -> Task:
     for output in tdata_py["outputs"]:
         if output not in outputs:
             outputs.append(output)
+    outputs.append({"identifier": "workgraph.any", "name": "exit_code"})
     # change "copy_files" link_limit to 1e6
     for input in inputs:
         if input["name"] == "copy_files":
@@ -322,6 +324,8 @@ def build_pythonjob_task(func: Callable) -> Task:
     tdata["outputs"] = outputs
     tdata["kwargs"] = kwargs
     tdata["task_type"] = "PYTHONJOB"
+    tdata["identifier"] = "workgraph.pythonjob"
+    tdata["node_class"] = PythonJobTask
     task = create_task(tdata)
     task.is_aiida_component = True
     return task, tdata
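A note on the `python_parser.py` hunks above: a `PythonJob` function can now report failure by including an `exit_code` entry in its returned dictionary, either as a bare integer or as a `{"status": ..., "message": ...}` mapping. The sketch below isolates the conversion the parser applies to that entry; `to_exit_code` is a hypothetical helper name, not part of the patch.

```python
# Hedged sketch of the conversion at the end of `parse` above;
# `to_exit_code` is a hypothetical helper, not part of the patch.
from aiida.engine import ExitCode


def to_exit_code(exit_code):
    if isinstance(exit_code, dict):
        # e.g. {"status": 410, "message": "Some elements are negative"}
        return ExitCode(exit_code["status"], exit_code["message"])
    if isinstance(exit_code, int):
        # a bare integer becomes an ExitCode with the default message
        return ExitCode(exit_code)
    return exit_code


assert to_exit_code(410).status == 410
assert to_exit_code({"status": 1, "message": "Sum is negative"}).message == "Sum is negative"
```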
diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py
index f2eacd23..ba0afcc6 100644
--- a/aiida_workgraph/engine/workgraph.py
+++ b/aiida_workgraph/engine/workgraph.py
@@ -532,12 +532,20 @@ def update_workgraph_from_base(self) -> None:
     def get_task(self, name: str):
         """Get task from the context."""
         task = Task.from_dict(self.ctx._tasks[name])
+        # update task results
+        for output in task.outputs:
+            output.value = get_nested_dict(
+                self.ctx._tasks[name]["results"],
+                output.name,
+                default=output.value,
+            )
         return task
 
     def update_task(self, task: Task):
         """Update task in the context.
         This is used in error handlers to update the task parameters."""
-        self.ctx._tasks[task.name]["properties"] = task.properties_to_dict()
+        tdata = task.to_dict()
+        self.ctx._tasks[task.name]["properties"] = tdata["properties"]
         self.reset_task(task.name)
 
     def get_task_state_info(self, name: str, key: str) -> str:
diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py
index 70dcd23a..55ca15b2 100644
--- a/aiida_workgraph/task.py
+++ b/aiida_workgraph/task.py
@@ -112,11 +112,15 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta
 
         Returns:
             Node: An instance of Node initialized with the provided data."""
         from aiida_workgraph.tasks import task_pool
+        from aiida.orm.utils.serialize import deserialize_unsafe
 
-        task = super().from_dict(data, node_pool=task_pool)
+        task = GraphNode.from_dict(data, node_pool=task_pool)
         task.context_mapping = data.get("context_mapping", {})
         task.waiting_on.add(data.get("wait", []))
-        task.process = data.get("process", None)
+        process = data.get("process", None)
+        if process and isinstance(process, str):
+            process = deserialize_unsafe(process)
+        task.process = process
         return task
 
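For context on the `get_task` change in `engine/workgraph.py`: every output socket is refreshed from the results the engine stores in its context, so an error handler sees the task's latest values. `get_nested_dict` is assumed to resolve dotted names such as `add_multiply.add`; the re-implementation below is a minimal sketch of that behaviour, not the library's own helper.

```python
# Minimal sketch of the lookup used in `get_task` above; `results` mimics
# self.ctx._tasks[name]["results"]. This re-implements the assumed behaviour
# of get_nested_dict; it is not the library's actual helper.
def get_nested_dict(d: dict, name: str, default=None):
    current = d
    for key in name.split("."):
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current


results = {"sum": 3, "add_multiply": {"add": 5}}
assert get_nested_dict(results, "sum") == 3
assert get_nested_dict(results, "add_multiply.add") == 5
assert get_nested_dict(results, "missing", default=-1) == -1
```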
+ """ + input_kwargs = cls.get_function_kwargs(tdata) + + for name in input_kwargs: + if name in tdata["properties"]: + value = tdata["properties"][name]["value"] + if isinstance(value, orm.Data): + value = value.value + elif value is not None and value != {}: + raise ValueError(f"There something wrong with the input {name}") + tdata["properties"][name]["value"] = value diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index b5b18442..5752c03d 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -218,38 +218,6 @@ def get_dict_from_builder(builder: Any) -> Dict: return builder -def get_pythonjob_data(tdata: Dict[str, Any]) -> Dict[str, Any]: - """ - Process the task data dictionary for a PythonJob. - It load the orignal Python data from the AiiDA Data node for the - args and kwargs of the function. - - Args: - tdata (Dict[str, Any]): The input data dictionary. - - Returns: - Dict[str, Any]: The processed data dictionary. - """ - for name in tdata["metadata"]["args"]: - if tdata["properties"][name]["value"] is None: - continue - if name in tdata["properties"]: - tdata["properties"][name]["value"] = tdata["properties"][name][ - "value" - ].value - for name in tdata["metadata"]["kwargs"]: - # all the kwargs are after computer is the input for the PythonJob, should be AiiDA Data node - if tdata["properties"][name]["value"] is None: - continue - if name == "computer": - break - if name in tdata["properties"]: - tdata["properties"][name]["value"] = tdata["properties"][name][ - "value" - ].value - return tdata - - def serialize_workgraph_data(wgdata: Dict[str, Any]) -> Dict[str, Any]: from aiida.orm.utils.serialize import serialize @@ -270,8 +238,6 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any] return for name, task in wgdata["tasks"].items(): wgdata["tasks"][name] = deserialize_unsafe(task) - if wgdata["tasks"][name]["metadata"]["node_type"].upper() == "PYTHONJOB": - get_pythonjob_data(wgdata["tasks"][name]) wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) return wgdata @@ -388,33 +354,14 @@ def serialize_properties(wgdata): save it to the node.base.extras. yaml can not handle the function defined in a scope, e.g., local function in another function. So, if a function is used as input, we needt to serialize the function. 
diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py
index b5b18442..5752c03d 100644
--- a/aiida_workgraph/utils/__init__.py
+++ b/aiida_workgraph/utils/__init__.py
@@ -218,38 +218,6 @@ def get_dict_from_builder(builder: Any) -> Dict:
     return builder
 
 
-def get_pythonjob_data(tdata: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Process the task data dictionary for a PythonJob.
-    It load the orignal Python data from the AiiDA Data node for the
-    args and kwargs of the function.
-
-    Args:
-        tdata (Dict[str, Any]): The input data dictionary.
-
-    Returns:
-        Dict[str, Any]: The processed data dictionary.
-    """
-    for name in tdata["metadata"]["args"]:
-        if tdata["properties"][name]["value"] is None:
-            continue
-        if name in tdata["properties"]:
-            tdata["properties"][name]["value"] = tdata["properties"][name][
-                "value"
-            ].value
-    for name in tdata["metadata"]["kwargs"]:
-        # all the kwargs are after computer is the input for the PythonJob, should be AiiDA Data node
-        if tdata["properties"][name]["value"] is None:
-            continue
-        if name == "computer":
-            break
-        if name in tdata["properties"]:
-            tdata["properties"][name]["value"] = tdata["properties"][name][
-                "value"
-            ].value
-    return tdata
-
-
 def serialize_workgraph_data(wgdata: Dict[str, Any]) -> Dict[str, Any]:
     from aiida.orm.utils.serialize import serialize
 
@@ -270,8 +238,6 @@ def get_workgraph_data(process: Union[int, orm.Node]) -> Optional[Dict[str, Any]
         return
     for name, task in wgdata["tasks"].items():
         wgdata["tasks"][name] = deserialize_unsafe(task)
-        if wgdata["tasks"][name]["metadata"]["node_type"].upper() == "PYTHONJOB":
-            get_pythonjob_data(wgdata["tasks"][name])
     wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
     return wgdata
 
@@ -388,33 +354,14 @@ def serialize_properties(wgdata):
     save it to the node.base.extras. yaml can not handle the function defined in
     a scope, e.g., local function in another function. So, if a function is used as
     input, we needt to serialize the function.
-
-    For PythonJob, serialize the function inputs."""
-    from aiida_workgraph.orm.serializer import general_serializer
+    """
     from aiida_workgraph.orm.function_data import PickledLocalFunction
     import inspect
 
     for _, task in wgdata["tasks"].items():
-        if task["metadata"]["node_type"].upper() == "PYTHONJOB":
-            # get the names kwargs for the PythonJob, which are the inputs before _wait
-            input_kwargs = []
-            for input in task["inputs"]:
-                if input["name"] == "_wait":
-                    break
-                input_kwargs.append(input["name"])
-            for name in input_kwargs:
-                prop = task["properties"][name]
-                # if value is not None, not {}
-                if not (
-                    prop["value"] is None
-                    or isinstance(prop["value"], dict)
-                    and prop["value"] == {}
-                ):
-                    prop["value"] = general_serializer(prop["value"])
-        else:
-            for _, prop in task["properties"].items():
-                if inspect.isfunction(prop["value"]):
-                    prop["value"] = PickledLocalFunction(prop["value"]).store()
+        for _, prop in task["properties"].items():
+            if inspect.isfunction(prop["value"]):
+                prop["value"] = PickledLocalFunction(prop["value"]).store()
 
 
 def generate_bash_to_create_python_env(
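On the surviving branch of `serialize_properties`: as its docstring notes, yaml cannot represent a function defined inside another scope, so such values go through `PickledLocalFunction`. A quick illustration of the kind of value that triggers the `inspect.isfunction` check (hypothetical example, not from the patch):

```python
# Hypothetical illustration of a property value that serialize_properties
# would pickle: a local function, which yaml cannot represent.
import inspect


def make_adder(n):
    def adder(x):  # defined in a local scope
        return x + n
    return adder


prop_value = make_adder(3)
assert inspect.isfunction(prop_value)  # the check used in serialize_properties
assert prop_value(4) == 7
```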
diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb
index 9d86bdd3..fbcef4ab 100644
--- a/docs/source/built-in/pythonjob.ipynb
+++ b/docs/source/built-in/pythonjob.ipynb
@@ -23,7 +23,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "id": "c6b83fb5",
    "metadata": {},
    "outputs": [
@@ -33,7 +33,7 @@
       "Profile"
      ]
     },
-     "execution_count": 3,
+     "execution_count": 1,
      "metadata": {},
      "output_type": "execute_result"
    }
   ]
  },
@@ -2368,7 +2368,61 @@
   "id": "fe376995",
   "metadata": {},
   "source": [
-   "We can see that the `result.txt` file is retrieved from the remote computer and stored in the local repository."
+   "We can see that the `result.txt` file is retrieved from the remote computer and stored in the local repository.\n",
+   "\n",
+   "## Exit Code\n",
+   "\n",
+   "The `PythonJob` task includes a built-in output socket, `exit_code`, which serves as a mechanism for error handling and status reporting during task execution. This `exit_code` is an integer value where `0` indicates successful completion, and any non-zero value signals that an error occurred.\n",
+   "\n",
+   "### How it Works:\n",
+   "When the function returns a dictionary with an `exit_code` key, the system automatically parses and uses this code to indicate the task's status. In the case of an error, the non-zero `exit_code` value helps identify the specific problem.\n",
+   "\n",
+   "\n",
+   "### Benefits of `exit_code`:\n",
+   "\n",
+   "1. **Error Reporting:** \n",
+   "   If the task encounters an error, the `exit_code` can communicate the reason. This is helpful during process inspection to determine why a task failed.\n",
+   "\n",
+   "2. **Error Handling and Recovery:** \n",
+   "   You can utilize `exit_code` to add specific error handlers for particular exit codes. This allows you to modify the task's parameters and restart it.\n",
+   "\n",
+   "\n",
+   "Below is an example Python function that uses `exit_code` to handle potential errors:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "a96cbbcb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "WorkGraph process created, PK: 146751\n",
+      "exit status: 1\n",
+      "exit message: Sum is negative\n"
+     ]
+    }
+   ],
+   "source": [
+    "from aiida_workgraph import WorkGraph, task\n",
+    "\n",
+    "@task.pythonjob(outputs=[{\"name\": \"sum\"}])\n",
+    "def add(x: int, y: int) -> int:\n",
+    "    sum = x + y\n",
+    "    if sum < 0:\n",
+    "        exit_code = {\"status\": 1, \"message\": \"Sum is negative\"}\n",
+    "        return {\"sum\": sum, \"exit_code\": exit_code}\n",
+    "    return {\"sum\": sum}\n",
+    "\n",
+    "wg = WorkGraph(\"test_PythonJob\")\n",
+    "wg.add_task(add, name=\"add\", x=1, y=-2)\n",
+    "wg.submit(wait=True)\n",
+    "\n",
+    "print(\"exit status: \", wg.tasks[\"add\"].node.exit_status)\n",
+    "print(\"exit message: \", wg.tasks[\"add\"].node.exit_message)"
   ]
  },
  {
@@ -2376,6 +2430,8 @@
   "id": "8d4d935b",
   "metadata": {},
   "source": [
+   "In this example, the task fails with exit status `1` because the sum is negative, and the message `Sum is negative` is recorded as the process's exit message.\n",
+   "\n",
    "## Define your data serializer\n",
    "Workgraph search data serializer from the `aiida.data` entry point by the module name and class name (e.g., `ase.atoms.Atoms`). \n",
    "\n",
diff --git a/pyproject.toml b/pyproject.toml
index bd3abf49..d6963fb9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,6 +122,7 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph"
 "workgraph.test_greater" = "aiida_workgraph.tasks.test:TestGreater"
 "workgraph.test_sum_diff" = "aiida_workgraph.tasks.test:TestSumDiff"
 "workgraph.test_arithmetic_multiply_add" = "aiida_workgraph.tasks.test:TestArithmeticMultiplyAdd"
+"workgraph.pythonjob" = "aiida_workgraph.tasks.pythonjob:PythonJob"
 
 [project.entry-points."aiida_workgraph.property"]
 "workgraph.any" = "aiida_workgraph.properties.builtins:PropertyAny"
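The `pyproject.toml` hunk registers the new task class under the `workgraph.pythonjob` name so it can be resolved by identifier. A rough sketch of resolving it at runtime; the entry-point group name `aiida_workgraph.task` is assumed from the surrounding entries, since the group header itself sits outside the hunk:

```python
# Rough sketch: load the newly registered task class via its entry point.
# The group name "aiida_workgraph.task" is an assumption; the header is
# not visible in the hunk above. Requires Python >= 3.10 for the
# group= keyword of entry_points().
from importlib.metadata import entry_points

matches = [
    ep
    for ep in entry_points(group="aiida_workgraph.task")
    if ep.name == "workgraph.pythonjob"
]
if matches:
    PythonJobTask = matches[0].load()  # aiida_workgraph.tasks.pythonjob:PythonJob
```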
+ """ + self.report("Run error handler: handle_negative_sum.") + # load the task from the WorkGraph engine + task = self.get_task(task_name) + # modify task inputs + task.set({"x": abs(task.inputs["x"].value), "y": abs(task.inputs["y"].value)}) + + self.update_task(task) + + wg = WorkGraph("test_PythonJob") + wg.add_task( + add, + name="add1", + x=array([1, 1]), + y=array([1, -2]), + computer="localhost", + code_label=python_executable_path, + ) + # register error handler + wg.attach_error_handler( + handle_negative_sum, + name="handle_negative_sum", + tasks={"add1": {"exit_codes": [410], "max_retries": 5}}, + ) + wg.run() + # the first task should have exit status 410 + assert wg.process.base.links.get_outgoing().all()[0].node.exit_status == 410 + assert ( + wg.process.base.links.get_outgoing().all()[0].node.exit_message + == "Some elements are negative" + ) + # the final task should have exit status 0 + assert wg.tasks["add1"].node.exit_status == 0 + assert (wg.tasks["add1"].outputs["sum"].value.value == array([2, 3])).all()