Skip to content

Commit

Permalink
Add pass_task_error_to_task_output and pass_task_error_to_task_output…
Browse files Browse the repository at this point in the history
…_path to TaskExecutionProperties (#54)

* Add pass_task_error_to_task_output to TaskExecutionProperties

- Added pass_task_error_to_task_output attribute to TaskExecutionProperties.
- Implemented error handling to pass detailed error information to task output.

* Corrected code style issues

* Handle deprecated fields with warnings and logging

* Update test_execute_handle to match new expected validation errors
  • Loading branch information
MartinFoka authored May 27, 2024
1 parent 1e0f6fd commit 943be1b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 42 deletions.
44 changes: 42 additions & 2 deletions frinx/common/worker/task_def.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import Any
from typing import Optional

from pydantic import BaseModel
Expand Down Expand Up @@ -119,8 +120,16 @@ def __init__(self, error_msg: str) -> None:
class TaskExecutionProperties(BaseModel):
exclude_empty_inputs: bool = False
transform_string_to_json_valid: bool = False
pass_worker_input_exception_to_task_output: bool = False
worker_input_exception_task_output_path: str = 'result.error'
pass_task_error_to_task_output: bool = True
pass_task_error_to_task_output_path: str = 'result.error'
pass_worker_input_exception_to_task_output: bool = Field(
default=False,
description='deprecation_flag',
)
worker_input_exception_task_output_path: str = Field(
default='result.error',
description='deprecation_flag',
)

model_config = ConfigDict(
frozen=True,
Expand All @@ -129,3 +138,34 @@ class TaskExecutionProperties(BaseModel):
arbitrary_types_allowed=False,
populate_by_name=False
)

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._handle_deprecated_fields()

def _handle_deprecated_fields(self) -> None:
import logging
import warnings

logger = logging.getLogger(__name__)

if self.model_fields['pass_worker_input_exception_to_task_output'].description != 'deprecation_flag':
message = (
"The 'pass_worker_input_exception_to_task_output' field is deprecated and should not be used. "
"Use 'pass_task_error_to_task_output' instead."
)
warnings.warn(message, DeprecationWarning)
logger.warning(message)
object.__setattr__(self, 'pass_task_error_to_task_output',
self.pass_worker_input_exception_to_task_output)

if self.model_fields['worker_input_exception_task_output_path'].description != 'deprecation_flag':
message = (
"The 'worker_input_exception_task_output_path' field is deprecated and should not be used. "
"Use 'pass_task_error_to_task_output_path' instead."
)
warnings.warn(message, DeprecationWarning)
logger.warning(message)

object.__setattr__(self, 'pass_task_error_to_task_output_path',
self.worker_input_exception_task_output_path)
28 changes: 15 additions & 13 deletions frinx/common/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,21 @@ def exception_response_handler(self, error: Exception, **kwargs: Any) -> TaskRes
logs=[TaskExecLog(f'{error_name}: {error}')]
)

match error:
case ValidationError():
formatted_error: DictAny = self._validate_exception_format(error)
task_result.logs = [TaskExecLog(f'{error_name}: {formatted_error}')]

if execution_properties.pass_worker_input_exception_to_task_output:
task_result.output = self._parse_exception_output_path_to_dict(
dot_path={
execution_properties.worker_input_exception_task_output_path: formatted_error
}
)

logger.error('%s error occurred: %s \n%s', error_name, error, str(traceback.format_exc()))
if execution_properties.pass_worker_input_exception_to_task_output:
match error:
case ValidationError():
error_info: str | DictAny = self._validate_exception_format(error)
case _:
error_info = str(error)

error_dict = {'error_name': error_name, 'error_info': error_info}
error_dict_with_output_path = self._parse_exception_output_path_to_dict(
dot_path={execution_properties.pass_task_error_to_task_output_path: error_dict})

task_result.output = TaskOutput(**error_dict_with_output_path)

logger.error('%s error occurred: %s \n%s', error_name, error, str(traceback.format_exc()))

return task_result

def execute_wrapper(self, task: RawTaskIO) -> Any:
Expand Down
41 changes: 14 additions & 27 deletions tests/unit_tests/common_test/worker_test/test_task_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,33 +202,20 @@ def test_execute_handle(self) -> None:
)
)

response: DictAny = DictAny({
'status': 'FAILED',
'output': {
'output': {
'error': {
'string_list': {
'type': 'missing',
'message': 'Field required'
},
'and_dict': {
'type': 'missing',
'message': 'Field required'
},
'required_string': {
'type': 'missing',
'message': 'Field required'
}
}
}
},
'logs': [
"ValidationError: {'string_list': {'type': 'missing', 'message': 'Field required'}, "
"'and_dict': {'type': 'missing', 'message': 'Field required'}, 'required_string': "
"{'type': 'missing', 'message': 'Field required'}}"
]
})
expected_result = {'output': {'error': {'error_name': 'ValidationError',
'error_info':
{
'string_list':
{'type': 'missing', 'message': 'Field required'},
'and_dict':
{'type': 'missing', 'message': 'Field required'},
'required_string':
{'type': 'missing', 'message': 'Field required'}
}
}
}
}

worker = MockExecuteProperties()
result = worker.execute_wrapper(task=worker_input_dict)
assert result == response
assert result.get('output') == expected_result

0 comments on commit 943be1b

Please sign in to comment.