diff --git a/lib/galaxy/model/__init__.py b/lib/galaxy/model/__init__.py
index 074b3f2df062..269522a69b18 100644
--- a/lib/galaxy/model/__init__.py
+++ b/lib/galaxy/model/__init__.py
@@ -6845,9 +6845,10 @@ def dataset_elements_and_identifiers(self, identifiers=None):
     def first_dataset_element(self) -> Optional["DatasetCollectionElement"]:
         for element in self.elements:
             if element.is_collection:
-                first_element = element.child_collection.first_dataset_element
-                if first_element:
-                    return first_element
+                if element.child_collection:
+                    first_element = element.child_collection.first_dataset_element
+                    if first_element:
+                        return first_element
             else:
                 return element
         return None
@@ -7019,7 +7020,7 @@ class HistoryDatasetCollectionAssociation(
     create_time: Mapped[datetime] = mapped_column(default=now, nullable=True)
     update_time: Mapped[datetime] = mapped_column(default=now, onupdate=now, index=True, nullable=True)

-    collection = relationship("DatasetCollection")
+    collection: Mapped["DatasetCollection"] = relationship("DatasetCollection")
     history: Mapped[Optional["History"]] = relationship(back_populates="dataset_collections")

     copied_from_history_dataset_collection_association = relationship(
@@ -7437,18 +7438,18 @@ class DatasetCollectionElement(Base, Dictifiable, Serializable):
     element_index: Mapped[Optional[int]]
     element_identifier: Mapped[Optional[str]] = mapped_column(Unicode(255))

-    hda = relationship(
+    hda: Mapped[Optional["HistoryDatasetAssociation"]] = relationship(
         "HistoryDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.hda_id == HistoryDatasetAssociation.id),
     )
-    ldda = relationship(
+    ldda: Mapped[Optional["LibraryDatasetDatasetAssociation"]] = relationship(
         "LibraryDatasetDatasetAssociation",
         primaryjoin=(lambda: DatasetCollectionElement.ldda_id == LibraryDatasetDatasetAssociation.id),
     )
-    child_collection = relationship(
+    child_collection: Mapped[Optional["DatasetCollection"]] = relationship(
         "DatasetCollection", primaryjoin=(lambda: DatasetCollectionElement.child_collection_id == DatasetCollection.id)
     )
-    collection = relationship(
+    collection: Mapped[DatasetCollection] = relationship(
         "DatasetCollection",
         primaryjoin=(lambda: DatasetCollection.id == DatasetCollectionElement.dataset_collection_id),
         back_populates="elements",
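Context for the typing changes above: the relationships move to SQLAlchemy 2.0-style `Mapped[...]` annotations, which is what makes `child_collection` visibly `Optional` and forces the `None` guard added in `first_dataset_element`. A minimal, self-contained sketch of the pattern, using hypothetical models rather than Galaxy's real ones:

```python
# Sketch of the Mapped[Optional[...]] relationship pattern (toy models,
# not Galaxy code).
from typing import Optional

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Collection(Base):
    __tablename__ = "collection"
    id: Mapped[int] = mapped_column(primary_key=True)


class Element(Base):
    __tablename__ = "element"
    id: Mapped[int] = mapped_column(primary_key=True)
    child_collection_id: Mapped[Optional[int]] = mapped_column(ForeignKey("collection.id"))
    # The nullable FK makes this Optional, so dereferencing
    # element.child_collection without a None check is now a type error,
    # exactly the guard added in first_dataset_element above.
    child_collection: Mapped[Optional["Collection"]] = relationship()
```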
""" diff --git a/lib/galaxy/tool_util/parser/output_models.py b/lib/galaxy/tool_util/parser/output_models.py index c16d423443a9..d72c653de28c 100644 --- a/lib/galaxy/tool_util/parser/output_models.py +++ b/lib/galaxy/tool_util/parser/output_models.py @@ -8,6 +8,7 @@ from typing import ( List, Optional, + Sequence, Union, ) @@ -105,7 +106,7 @@ class FilePatternDatasetCollectionDescription(DatasetCollectionDescription): ToolOutput = Annotated[ToolOutputT, Field(discriminator="type")] -def from_tool_source(tool_source: ToolSource) -> List[ToolOutput]: +def from_tool_source(tool_source: ToolSource) -> Sequence[ToolOutput]: tool_outputs, tool_output_collections = tool_source.parse_outputs(None) outputs = [] for tool_output in tool_outputs.values(): diff --git a/lib/galaxy/tool_util/parser/output_objects.py b/lib/galaxy/tool_util/parser/output_objects.py index 63148c1fb946..bedff968f89a 100644 --- a/lib/galaxy/tool_util/parser/output_objects.py +++ b/lib/galaxy/tool_util/parser/output_objects.py @@ -281,7 +281,7 @@ def __init__( self.collection = True self.default_format = default_format self.structure = structure - self.outputs: Dict[str, str] = {} + self.outputs: Dict[str, ToolOutput] = {} self.inherit_format = inherit_format self.inherit_metadata = inherit_metadata diff --git a/lib/galaxy/tools/__init__.py b/lib/galaxy/tools/__init__.py index bb62f87e05d5..757483abab97 100644 --- a/lib/galaxy/tools/__init__.py +++ b/lib/galaxy/tools/__init__.py @@ -83,6 +83,10 @@ PageSource, ToolSource, ) +from galaxy.tool_util.parser.output_objects import ( + ToolOutput, + ToolOutputCollection, +) from galaxy.tool_util.parser.util import ( parse_profile_version, parse_tool_version_with_defaults, @@ -847,6 +851,8 @@ def __init__( self.tool_errors = None # Parse XML element containing configuration self.tool_source = tool_source + self.outputs: Dict[str, ToolOutput] = {} + self.output_collections: Dict[str, ToolOutputCollection] = {} self._is_workflow_compatible = None self.__help = None self.__tests: Optional[str] = None diff --git a/lib/galaxy/tools/actions/__init__.py b/lib/galaxy/tools/actions/__init__.py index 841eea988d49..fd5f97686fba 100644 --- a/lib/galaxy/tools/actions/__init__.py +++ b/lib/galaxy/tools/actions/__init__.py @@ -9,6 +9,7 @@ cast, Dict, List, + MutableMapping, Optional, Set, Tuple, @@ -533,7 +534,7 @@ def handle_output(name, output, hidden=None): output, wrapped_params.params, inp_data, - inp_dataset_collections, + input_collections, input_ext, python_template_version=tool.python_template_version, execution_cache=execution_cache, @@ -1156,7 +1157,7 @@ def determine_output_format( output: "ToolOutput", parameter_context, input_datasets, - input_dataset_collections, + input_dataset_collections: MutableMapping[str, model.HistoryDatasetCollectionAssociation], random_input_ext, python_template_version="3", execution_cache=None, @@ -1198,7 +1199,7 @@ def determine_output_format( if collection_name in input_dataset_collections: try: - input_collection = input_dataset_collections[collection_name][0][0] + input_collection = input_dataset_collections[collection_name] input_collection_collection = input_collection.collection if element_index is None: # just pick the first HDA diff --git a/lib/galaxy/tools/evaluation.py b/lib/galaxy/tools/evaluation.py index 582bd65be06e..1d1a11914fc1 100644 --- a/lib/galaxy/tools/evaluation.py +++ b/lib/galaxy/tools/evaluation.py @@ -33,6 +33,7 @@ MinimalToolApp, ) from galaxy.tool_util.data import TabularToolDataTable +from galaxy.tools.actions import 
diff --git a/lib/galaxy/tools/evaluation.py b/lib/galaxy/tools/evaluation.py
index 582bd65be06e..1d1a11914fc1 100644
--- a/lib/galaxy/tools/evaluation.py
+++ b/lib/galaxy/tools/evaluation.py
@@ -33,6 +33,7 @@
     MinimalToolApp,
 )
 from galaxy.tool_util.data import TabularToolDataTable
+from galaxy.tools.actions import determine_output_format
 from galaxy.tools.parameters import (
     visit_input_values,
     wrapped_json,
@@ -64,6 +65,7 @@
     safe_makedirs,
     unicodify,
 )
+from galaxy.util.path import StrPath
 from galaxy.util.template import (
     fill_template,
     InputNotFoundSyntaxError,
@@ -102,7 +104,7 @@ def __init__(self, *args: object, tool_id: Optional[str], tool_version: str, is_
         self.is_latest = is_latest


-def global_tool_logs(func, config_file: str, action_str: str, tool: "Tool"):
+def global_tool_logs(func, config_file: Optional[StrPath], action_str: str, tool: "Tool"):
     try:
         return func()
     except Exception as e:
@@ -130,7 +132,7 @@ class ToolEvaluator:
     job: model.Job
     materialize_datasets: bool = True

-    def __init__(self, app: MinimalToolApp, tool, job, local_working_directory):
+    def __init__(self, app: MinimalToolApp, tool: "Tool", job, local_working_directory):
         self.app = app
         self.job = job
         self.tool = tool
@@ -186,6 +188,9 @@ def set_compute_environment(self, compute_environment: ComputeEnvironment, get_s
             out_data,
             output_collections=out_collections,
         )
+        # late update of format_source outputs
+        self._eval_format_source(job, inp_data, out_data)
+
         self.execute_tool_hooks(inp_data=inp_data, out_data=out_data, incoming=incoming)

     def execute_tool_hooks(self, inp_data, out_data, incoming):
@@ -275,6 +280,23 @@ def _materialize_objects(

         return undeferred_objects

+    def _eval_format_source(
+        self,
+        job: model.Job,
+        inp_data: Dict[str, Optional[model.DatasetInstance]],
+        out_data: Dict[str, model.DatasetInstance],
+    ):
+        for output_name, output in out_data.items():
+            if (
+                (tool_output := self.tool.outputs.get(output_name))
+                and (tool_output.format_source or tool_output.change_format)
+                and output.extension == "expression.json"
+            ):
+                input_collections = {jtidca.name: jtidca.dataset_collection for jtidca in job.input_dataset_collections}
+                ext = determine_output_format(tool_output, self.param_dict, inp_data, input_collections, None)
+                if ext:
+                    output.extension = ext
+
     def _replaced_deferred_objects(
         self,
         inp_data: Dict[str, Optional[model.DatasetInstance]],
@@ -364,6 +386,9 @@ def do_walk(inputs, input_values):
         do_walk(inputs, input_values)

     def __populate_wrappers(self, param_dict, input_datasets, job_working_directory):
+
+        element_identifier_mapper = ElementIdentifierMapper(input_datasets)
+
         def wrap_input(input_values, input):
             value = input_values[input.name]
             if isinstance(input, DataToolParameter) and input.multiple:
@@ -380,26 +405,26 @@ def wrap_input(input_values, input):

             elif isinstance(input, DataToolParameter):
                 dataset = input_values[input.name]
-                wrapper_kwds = dict(
+                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
+                input_values[input.name] = DatasetFilenameWrapper(
+                    dataset=dataset,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=input.name,
                     compute_environment=self.compute_environment,
+                    identifier=element_identifier,
+                    formats=input.formats,
                 )
-                element_identifier = element_identifier_mapper.identifier(dataset, param_dict)
-                if element_identifier:
-                    wrapper_kwds["identifier"] = element_identifier
-                wrapper_kwds["formats"] = input.formats
-                input_values[input.name] = DatasetFilenameWrapper(dataset, **wrapper_kwds)
             elif isinstance(input, DataCollectionToolParameter):
                 dataset_collection = value
-                wrapper_kwds = dict(
+                wrapper = DatasetCollectionWrapper(
+                    job_working_directory=job_working_directory,
+                    has_collection=dataset_collection,
                     datatypes_registry=self.app.datatypes_registry,
                     compute_environment=self.compute_environment,
                     tool=self.tool,
                     name=input.name,
                 )
-                wrapper = DatasetCollectionWrapper(job_working_directory, dataset_collection, **wrapper_kwds)
                 input_values[input.name] = wrapper
             elif isinstance(input, SelectToolParameter):
                 if input.multiple:
@@ -409,14 +434,13 @@ def wrap_input(input_values, input):
                 )
             else:
                 input_values[input.name] = InputValueWrapper(
-                    input, value, param_dict, profile=self.tool and self.tool.profile
+                    input, value, param_dict, profile=self.tool and self.tool.profile or None
                 )

         # HACK: only wrap if check_values is not false, this deals with external
         #       tools where the inputs don't even get passed through. These
         #       tools (e.g. UCSC) should really be handled in a special way.
         if self.tool.check_values:
-            element_identifier_mapper = ElementIdentifierMapper(input_datasets)
             self.__walk_inputs(self.tool.inputs, param_dict, wrap_input)

     def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
@@ -443,13 +467,13 @@ def __populate_input_dataset_wrappers(self, param_dict, input_datasets):
                 param_dict[name] = wrapper
                 continue
             if not isinstance(param_dict_value, ToolParameterValueWrapper):
-                wrapper_kwds = dict(
+                param_dict[name] = DatasetFilenameWrapper(
+                    dataset=data,
                     datatypes_registry=self.app.datatypes_registry,
                     tool=self.tool,
                     name=name,
                     compute_environment=self.compute_environment,
                 )
-                param_dict[name] = DatasetFilenameWrapper(data, **wrapper_kwds)

     def __populate_output_collection_wrappers(self, param_dict, output_collections, job_working_directory):
         tool = self.tool
@@ -460,14 +484,15 @@ def __populate_output_collection_wrappers(self, param_dict, output_collections,
                 # message = message_template % ( name, tool.output_collections )
                 # raise AssertionError( message )

-            wrapper_kwds = dict(
+            wrapper = DatasetCollectionWrapper(
+                job_working_directory=job_working_directory,
+                has_collection=out_collection,
                 datatypes_registry=self.app.datatypes_registry,
                 compute_environment=self.compute_environment,
                 io_type="output",
                 tool=tool,
                 name=name,
             )
-            wrapper = DatasetCollectionWrapper(job_working_directory, out_collection, **wrapper_kwds)
             param_dict[name] = wrapper

             # TODO: Handle nested collections...
             for element_identifier, output_def in tool.output_collections[name].outputs.items():
@@ -662,6 +687,7 @@ def _build_command_line(self):
         if interpreter:
             # TODO: path munging for cluster/dataset server relocatability
             executable = command_line.split()[0]
+            assert self.tool.tool_dir
             tool_dir = os.path.abspath(self.tool.tool_dir)
             abs_executable = os.path.join(tool_dir, executable)
             command_line = command_line.replace(executable, f"{interpreter} {shlex.quote(abs_executable)}", 1)
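The heart of the fix is `_eval_format_source` above: an expression tool's output keeps the placeholder `expression.json` extension at job creation, and only at evaluation time, once the real inputs are known, can `format_source`/`change_format` be resolved. A condensed, framework-free sketch of that pass, with simplified stand-in types invented for the example:

```python
# Framework-free sketch of the late format_source pass; FakeToolOutput,
# FakeDataset and resolve_format are simplified stand-ins, not Galaxy APIs.
from dataclasses import dataclass
from typing import Callable, Mapping, Optional


@dataclass
class FakeToolOutput:
    format_source: Optional[str] = None
    change_format: bool = False


@dataclass
class FakeDataset:
    extension: str


def late_format_source(
    tool_outputs: Mapping[str, FakeToolOutput],
    out_data: Mapping[str, FakeDataset],
    resolve_format: Callable[[FakeToolOutput], Optional[str]],
) -> None:
    for name, dataset in out_data.items():
        tool_output = tool_outputs.get(name)
        if tool_output is None:
            continue
        deferred = tool_output.format_source or tool_output.change_format
        # Only outputs still carrying the placeholder extension are revisited.
        if deferred and dataset.extension == "expression.json":
            if ext := resolve_format(tool_output):
                dataset.extension = ext
```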
diff --git a/lib/galaxy/tools/parameters/basic.py b/lib/galaxy/tools/parameters/basic.py
index b578739ccdb9..3befa87ae1e3 100644
--- a/lib/galaxy/tools/parameters/basic.py
+++ b/lib/galaxy/tools/parameters/basic.py
@@ -1998,6 +1998,7 @@ def do_validate(v):
                     dataset_count += 1
                     do_validate(v.hda)
                 else:
+                    assert v.child_collection
                     for dataset_instance in v.child_collection.dataset_instances:
                         dataset_count += 1
                         do_validate(dataset_instance)
@@ -2176,33 +2177,39 @@ def from_json(self, value, trans, other_values=None):
         dataset_matcher_factory = get_dataset_matcher_factory(trans)
         dataset_matcher = dataset_matcher_factory.dataset_matcher(self, other_values)
         for v in rval:
+            value_to_check: Union[
+                DatasetInstance, DatasetCollection, DatasetCollectionElement, HistoryDatasetCollectionAssociation
+            ] = v
             if isinstance(v, DatasetCollectionElement):
                 if hda := v.hda:
-                    v = hda
+                    value_to_check = hda
                 elif ldda := v.ldda:
-                    v = ldda
+                    value_to_check = ldda
                 elif collection := v.child_collection:
-                    v = collection
-                elif not v.collection and v.collection.populated_optimized:
+                    value_to_check = collection
+                elif v.collection and not v.collection.populated_optimized:
                     raise ParameterValueError("the selected collection has not been populated.", self.name)
                 else:
                     raise ParameterValueError("Collection element in unexpected state", self.name)
-            if isinstance(v, DatasetInstance):
-                if v.deleted:
+            if isinstance(value_to_check, DatasetInstance):
+                if value_to_check.deleted:
                     raise ParameterValueError("the previously selected dataset has been deleted.", self.name)
-                elif v.dataset and v.dataset.state in [Dataset.states.ERROR, Dataset.states.DISCARDED]:
+                elif value_to_check.dataset and value_to_check.dataset.state in [
+                    Dataset.states.ERROR,
+                    Dataset.states.DISCARDED,
+                ]:
                     raise ParameterValueError(
                         "the previously selected dataset has entered an unusable state", self.name
                     )
-                match = dataset_matcher.hda_match(v)
+                match = dataset_matcher.hda_match(value_to_check)
                 if match and match.implicit_conversion:
-                    v.implicit_conversion = True  # type:ignore[union-attr]
-            elif isinstance(v, HistoryDatasetCollectionAssociation):
-                if v.deleted:
+                    value_to_check.implicit_conversion = True  # type:ignore[attr-defined]
+            elif isinstance(value_to_check, HistoryDatasetCollectionAssociation):
+                if value_to_check.deleted:
                     raise ParameterValueError("the previously selected dataset collection has been deleted.", self.name)
-                v = v.collection
-            if isinstance(v, DatasetCollection):
-                if v.elements_deleted:
+                value_to_check = value_to_check.collection
+            if isinstance(value_to_check, DatasetCollection):
+                if value_to_check.elements_deleted:
                     raise ParameterValueError(
                         "the previously selected dataset collection has elements that are deleted.", self.name
                     )
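Two separate problems are fixed in the `from_json` hunk above: the populated check was inverted (`not v.collection and v.collection.populated_optimized` dereferences `v.collection` only when it is falsy, so it could never fire correctly), and reassigning the loop variable `v` across several types defeats mypy's narrowing, hence the dedicated `value_to_check` name. A toy version of the narrowing pattern (types invented for the sketch):

```python
# Toy demonstration of why a separately annotated name helps mypy narrow;
# Element and its payload are invented for this example.
from typing import Union


class Element:
    payload: Union[int, str] = 0


def check(v: Union[Element, int, str]) -> str:
    value_to_check: Union[Element, int, str] = v
    if isinstance(v, Element):
        # Rebinding v itself would widen its type again on later branches;
        # narrowing a dedicated name keeps each isinstance() check sound.
        value_to_check = v.payload
    if isinstance(value_to_check, int):
        return str(value_to_check)
    if isinstance(value_to_check, str):
        return value_to_check
    raise ValueError("unexpected element state")
```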
diff --git a/lib/galaxy/tools/wrappers.py b/lib/galaxy/tools/wrappers.py
index 12d3d779cafc..239065af74ce 100644
--- a/lib/galaxy/tools/wrappers.py
+++ b/lib/galaxy/tools/wrappers.py
@@ -517,7 +517,7 @@ def __iter__(self) -> Iterator[Any]:
         pass

     def _dataset_wrapper(
-        self, dataset: Union[DatasetInstance, DatasetCollectionElement], **kwargs: Any
+        self, dataset: Optional[Union[DatasetInstance, DatasetCollectionElement]], **kwargs: Any
     ) -> DatasetFilenameWrapper:
         return DatasetFilenameWrapper(dataset, **kwargs)

@@ -647,6 +647,7 @@ def __init__(
             collection = has_collection.collection
             self.name = has_collection.name
         elif isinstance(has_collection, DatasetCollectionElement):
+            assert has_collection.child_collection
             collection = has_collection.child_collection
             self.name = has_collection.element_identifier
         else:
@@ -661,8 +662,9 @@ def __init__(
         for dataset_collection_element in elements:
             element_object = dataset_collection_element.element_object
             element_identifier = dataset_collection_element.element_identifier
+            assert element_identifier is not None

-            if dataset_collection_element.is_collection:
+            if isinstance(element_object, DatasetCollection):
                 element_wrapper: DatasetCollectionElementWrapper = DatasetCollectionWrapper(
                     job_working_directory, dataset_collection_element, **kwargs
                 )
diff --git a/lib/galaxy_test/api/test_workflows.py b/lib/galaxy_test/api/test_workflows.py
index 6d14cb49e7df..48f623a19487 100644
--- a/lib/galaxy_test/api/test_workflows.py
+++ b/lib/galaxy_test/api/test_workflows.py
@@ -2260,6 +2260,38 @@ def test_run_workflow_pick_value_bam_pja(self):
             assert dataset_details["metadata_bam_index"]
             assert dataset_details["file_ext"] == "bam"

+    def test_expression_tool_output_in_format_source(self):
+        with self.dataset_populator.test_history() as history_id:
+            self._run_workflow(
+                """class: GalaxyWorkflow
+inputs:
+  input:
+    type: data
+steps:
+  skip:
+    tool_id: cat_data_and_sleep
+    in:
+      input1: input
+    when: $(false)
+  pick_larger:
+    tool_id: expression_pick_larger_file
+    in:
+      input1: skip/out_file1
+      input2: input
+  format_source:
+    tool_id: cat_data_and_sleep
+    in:
+      input1: pick_larger/larger_file
+test_data:
+  input:
+    value: 1.fastqsanger.gz
+    type: File
+    file_type: fastqsanger.gz
+""",
+                history_id=history_id,
+            )
+            self.dataset_populator.wait_for_history(history_id=history_id, assert_ok=True)
+
     def test_run_workflow_simple_conditional_step(self):
         with self.dataset_populator.test_history() as history_id:
             summary = self._run_workflow(
diff --git a/test/functional/tools/expression_pick_larger_file.xml b/test/functional/tools/expression_pick_larger_file.xml
index daed2d9175ea..e624ed302f26 100644
--- a/test/functional/tools/expression_pick_larger_file.xml
+++ b/test/functional/tools/expression_pick_larger_file.xml
@@ -20,7 +20,7 @@
-
+
diff --git a/test/unit/app/tools/test_actions.py b/test/unit/app/tools/test_actions.py
index 36056fbc61d9..a6126cfb8b32 100644
--- a/test/unit/app/tools/test_actions.py
+++ b/test/unit/app/tools/test_actions.py
@@ -231,7 +231,7 @@ def __assert_output_format_is(expected, output, input_extensions=None, param_con
     )
     c1.elements = [dce1, dce2]

-    input_collections["hdcai"] = [(hc1, False)]
+    input_collections["hdcai"] = hc1

     actual_format = determine_output_format(output, param_context, inputs, input_collections, last_ext)
     assert actual_format == expected, f"Actual format {actual_format}, does not match expected {expected}"
diff --git a/test/unit/app/tools/test_evaluation.py b/test/unit/app/tools/test_evaluation.py
index e571b5fd8898..27e4b4a147b5 100644
--- a/test/unit/app/tools/test_evaluation.py
+++ b/test/unit/app/tools/test_evaluation.py
@@ -45,7 +45,7 @@ def setUp(self):
         self.job.history = History()
         self.job.history.id = 42
         self.job.parameters = [JobParameter(name="thresh", value="4")]
-        self.evaluator = ToolEvaluator(self.app, self.tool, self.job, self.test_directory)
+        self.evaluator = ToolEvaluator(self.app, self.tool, self.job, self.test_directory)  # type: ignore[arg-type]

     def tearDown(self):
         self.tear_down_app()
diff --git a/test/unit/data/test_dataset_materialization.py b/test/unit/data/test_dataset_materialization.py
index be791a047cce..3573da8d3cab 100644
--- a/test/unit/data/test_dataset_materialization.py
+++ b/test/unit/data/test_dataset_materialization.py
@@ -374,6 +374,7 @@ def _deferred_element_count(dataset_collection: DatasetCollection) -> int:
     count = 0
     for element in dataset_collection.elements:
         if element.is_collection:
+            assert element.child_collection
             count += _deferred_element_count(element.child_collection)
         else:
             dataset_instance = element.dataset_instance
diff --git a/test/unit/tool_util/test_parsing.py b/test/unit/tool_util/test_parsing.py
index 21ec92387694..42e5bbfd1b1b 100644
--- a/test/unit/tool_util/test_parsing.py
+++ b/test/unit/tool_util/test_parsing.py
@@ -5,8 +5,8 @@
 from math import isinf
 from typing import (
     cast,
-    List,
     Optional,
+    Sequence,
     Type,
     TypeVar,
 )
@@ -262,7 +262,7 @@ def _tool_source(self):
         return self._get_tool_source()

     @property
-    def _output_models(self) -> List[ToolOutput]:
+    def _output_models(self) -> Sequence[ToolOutput]:
         return from_tool_source(self._tool_source)

     def _get_tool_source(self, source_file_name=None, source_contents=None, macro_contents=None):
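One last note on the `List` → `Sequence` return types in `output_models.py` and `test_parsing.py`: `List` is invariant, so a list containing only one member type of the `ToolOutput` union does not type-check as `List[ToolOutput]`, while the read-only, covariant `Sequence` accepts it. A self-contained sketch with toy classes, not the real models:

```python
# Why Sequence instead of List: List is invariant, Sequence is covariant.
# StaticOutput/DynamicOutput are toy classes for this sketch only.
from typing import List, Sequence, Union


class StaticOutput: ...


class DynamicOutput: ...


ToolOutput = Union[StaticOutput, DynamicOutput]


def outputs() -> Sequence[ToolOutput]:
    items: List[StaticOutput] = [StaticOutput()]
    # Fine for Sequence[ToolOutput]; mypy rejects this if the return
    # type is List[ToolOutput], because List is invariant.
    return items
```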