Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor workflow_state xtrig pre-8.3.0-back-compat #51

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 58 additions & 35 deletions cylc/flow/dbstatecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import os
import sqlite3
import sys
from typing import Optional, List
from textwrap import dedent
from typing import Dict, Iterable, Optional, List, Union

from cylc.flow import LOG
from cylc.flow.exceptions import InputError
from cylc.flow.cycling.util import add_offset
from cylc.flow.cycling.integer import (
Expand All @@ -33,13 +33,20 @@
from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.task_outputs import (
TASK_OUTPUT_SUCCEEDED,
TASK_OUTPUT_FAILED
TASK_OUTPUT_FAILED,
TASK_OUTPUT_FINISHED,
)
from cylc.flow.util import deserialise_set
from metomi.isodatetime.parsers import TimePointParser
from metomi.isodatetime.exceptions import ISO8601SyntaxError


output_fallback_msg = (
"Unable to filter by task output label for tasks run in Cylc versions "
"between 8.0.0-8.3.0. Falling back to filtering by task message instead."
)


class CylcWorkflowDBChecker:
"""Object for querying task status or outputs from a workflow database.

Expand Down Expand Up @@ -70,12 +77,12 @@ def __init__(self, rund, workflow, db_path=None):
# Get workflow point format.
try:
self.db_point_fmt = self._get_db_point_format()
self.back_compat_mode = False
self.c7_back_compat_mode = False
except sqlite3.OperationalError as exc:
# BACK COMPAT: Cylc 7 DB (see method below).
try:
self.db_point_fmt = self._get_db_point_format_compat()
self.back_compat_mode = True
self.c7_back_compat_mode = True
except sqlite3.OperationalError:
raise exc # original error

Expand Down Expand Up @@ -194,7 +201,7 @@ def workflow_state_query(
]
For an output query:
[
[name, cycle, "[out1: msg1, out2: msg2, ...]"],
[name, cycle, "{out1: msg1, out2: msg2, ...}"],
...
]
"""
Expand All @@ -208,16 +215,16 @@ def workflow_state_query(
target_table = CylcWorkflowDAO.TABLE_TASK_STATES
mask = "name, cycle, status"

if not self.back_compat_mode:
if not self.c7_back_compat_mode:
# Cylc 8 DBs only
mask += ", flow_nums"

stmt = dedent(rf'''
stmt = rf'''
SELECT
{mask}
FROM
{target_table}
''') # nosec
''' # nosec
# * mask is hardcoded
# * target_table is a code constant

Expand All @@ -241,20 +248,20 @@ def workflow_state_query(
stmt_wheres.append("cycle==?")
stmt_args.append(cycle)

if selector is not None and not (is_output or is_message):
if (
selector is not None
and target_table == CylcWorkflowDAO.TABLE_TASK_STATES
):
# Can select by status in the DB but not outputs.
stmt_wheres.append("status==?")
stmt_args.append(selector)

if stmt_wheres:
stmt += "WHERE\n " + (" AND ").join(stmt_wheres)

if not (is_output or is_message):
if target_table == CylcWorkflowDAO.TABLE_TASK_STATES:
# (outputs table doesn't record submit number)
stmt += dedent("""
ORDER BY
submit_num
""")
stmt += r"ORDER BY submit_num"

# Query the DB and drop incompatible rows.
db_res = []
Expand All @@ -264,7 +271,7 @@ def workflow_state_query(
if row[2] is None:
# status can be None in Cylc 7 DBs
continue
if not self.back_compat_mode:
if not self.c7_back_compat_mode:
flow_nums = deserialise_set(row[3])
if flow_num is not None and flow_num not in flow_nums:
# skip result, wrong flow
Expand All @@ -274,34 +281,50 @@ def workflow_state_query(
res.append(fstr)
db_res.append(res)

if not (is_output or is_message):
if target_table == CylcWorkflowDAO.TABLE_TASK_STATES:
return db_res

warn_output_fallback = is_output
results = []
for row in db_res:
outputs_map = json.loads(row[2])
if is_message:
# task message
try:
outputs = list(outputs_map.values())
except AttributeError:
# Cylc 8 pre 8.3.0 back-compat: list of output messages
outputs = list(outputs_map)
outputs: Union[Dict[str, str], List[str]] = json.loads(row[2])
if isinstance(outputs, dict):
messages: Iterable[str] = outputs.values()
else:
# task output
outputs = list(outputs_map)
# Cylc 8 pre 8.3.0 back-compat: list of output messages
messages = outputs
if warn_output_fallback:
LOG.warning(output_fallback_msg)
warn_output_fallback = False
Comment on lines -282 to +298
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clearer distinction between messages and outputs, added warning for 8.0.0-8.3.0 tasks


if (
selector is None or
selector in outputs or
(
selector in ("finished", "finish")
and (
TASK_OUTPUT_SUCCEEDED in outputs
or TASK_OUTPUT_FAILED in outputs
)
)
(is_message and selector in messages) or
(is_output and self._selector_in_outputs(selector, outputs))
):
results.append(row[:2] + [str(outputs)] + row[3:])

return results

@staticmethod
def _selector_in_outputs(selector: str, outputs: Iterable[str]) -> bool:
"""Check if a selector, including "finished", is in the outputs.

Examples:
>>> this = CylcWorkflowDBChecker._selector_in_outputs
>>> this('moop', ['started', 'moop'])
True
>>> this('moop', ['started'])
False
>>> this('finished', ['succeeded'])
True
>>> this('finish', ['failed'])
True
"""
return selector in outputs or (
selector in (TASK_OUTPUT_FINISHED, "finish")
and (
TASK_OUTPUT_SUCCEEDED in outputs
or TASK_OUTPUT_FAILED in outputs
)
)
9 changes: 5 additions & 4 deletions cylc/flow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,13 @@ class XtriggerConfigError(WorkflowConfigError):

"""

def __init__(self, label: str, message: str):
self.label: str = label
self.message: str = message
def __init__(self, label: str, func: str, message: Union[str, Exception]):
self.label = label
self.func = func
self.message = message

def __str__(self) -> str:
return f'[@{self.label}] {self.message}'
return f'[@{self.label}] {self.func}\n{self.message}'


class ClientError(CylcError):
Expand Down
18 changes: 12 additions & 6 deletions cylc/flow/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import traceback
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Set,
Expand All @@ -38,6 +39,7 @@

if TYPE_CHECKING:
from pathlib import Path
from cylc.flow.flow_mgr import FlowNums


@dataclass
Expand Down Expand Up @@ -806,10 +808,12 @@ def select_latest_flow_nums(self):
flow_nums_str = list(self.connect().execute(stmt))[0][0]
return deserialise_set(flow_nums_str)

def select_task_outputs(self, name, point):
def select_task_outputs(
self, name: str, point: str
) -> 'Dict[str, FlowNums]':
"""Select task outputs for each flow.

Return: {outputs_list: flow_nums_set}
Return: {outputs_dict_str: flow_nums_set}

"""
stmt = rf'''
Expand All @@ -820,10 +824,12 @@ def select_task_outputs(self, name, point):
WHERE
name==? AND cycle==?
''' # nosec (table name is code constant)
ret = {}
for flow_nums, outputs in self.connect().execute(stmt, (name, point,)):
ret[outputs] = deserialise_set(flow_nums)
return ret
return {
outputs: deserialise_set(flow_nums)
for flow_nums, outputs in self.connect().execute(
stmt, (name, point,)
)
}

def select_xtriggers_for_restart(self, callback):
stmt = rf'''
Expand Down
Loading
Loading