diff --git a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py index 1d97d28df2..3c23ad69ef 100644 --- a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py +++ b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py @@ -330,25 +330,14 @@ def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> # columns returned by the query we have to add columns from dimension # records that are not returned by the query. These columns belong to # either cached or skypix dimensions. - query_result_columns = set(spec.get_result_columns()) - output_columns = spec.get_all_result_columns() + columns = spec.get_result_columns() universe = spec.dimensions.universe self.converters: list[_GeneralColumnConverter] = [] - for column in output_columns: + self.record_converters: dict[DimensionElement, _DimensionRecordRowConverter] = {} + for column in columns: column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field) converter: _GeneralColumnConverter - if column not in query_result_columns and column.field is not None: - # This must be a field from a cached dimension record or - # skypix record. - assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here" - element = universe[column.logical_table] - if isinstance(element, SkyPixDimension): - converter = _SkypixRecordGeneralColumnConverter(element, column.field) - else: - converter = _CachedRecordGeneralColumnConverter( - element, column.field, ctx.dimension_record_cache - ) - elif column.field == TimespanDatabaseRepresentation.NAME: + if column.field == TimespanDatabaseRepresentation.NAME: converter = _TimespanGeneralColumnConverter(column_name, ctx.db) elif column.field == "ingest_date": converter = _TimestampGeneralColumnConverter(column_name) @@ -356,9 +345,29 @@ def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> converter = _DefaultGeneralColumnConverter(column_name) self.converters.append(converter) + if spec.include_dimension_records: + universe = self.spec.dimensions.universe + for element_name in self.spec.dimensions.elements: + element = universe[element_name] + if isinstance(element, SkyPixDimension): + self.record_converters[element] = _SkypixDimensionRecordRowConverter(element) + elif element.is_cached: + self.record_converters[element] = _CachedDimensionRecordRowConverter( + element, ctx.dimension_record_cache + ) + def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage: - rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows] - return GeneralResultPage(spec=self.spec, rows=rows) + rows = [] + dimension_records = None + if self.spec.include_dimension_records: + dimension_records = {element: DimensionRecordSet(element) for element in self.record_converters} + for row in raw_rows: + rows.append(tuple(cvt.convert(row) for cvt in self.converters)) + if dimension_records: + for element, converter in self.record_converters.items(): + dimension_records[element].add(converter.convert(row)) + + return GeneralResultPage(spec=self.spec, rows=rows, dimension_records=dimension_records) class _GeneralColumnConverter: @@ -440,47 +449,3 @@ def __init__(self, name: str, db: Database): def convert(self, row: sqlalchemy.Row) -> Any: timespan = self.timespan_class.extract(row._mapping, self.name) return timespan - - -class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter): - """Helper for converting result row into a field value for cached - dimension records. - - Parameters - ---------- - element : `DimensionElement` - Dimension element, must be of cached type. - field : `str` - Name of the field to extract from the dimension record. - cache : `DimensionRecordCache` - Cache for dimension records. - """ - - def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None: - self._record_converter = _CachedDimensionRecordRowConverter(element, cache) - self._field = field - - def convert(self, row: sqlalchemy.Row) -> Any: - record = self._record_converter.convert(row) - return getattr(record, self._field) - - -class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter): - """Helper for converting result row into a field value for skypix - dimension records. - - Parameters - ---------- - element : `SkyPixDimension` - Dimension element. - field : `str` - Name of the field to extract from the dimension record. - """ - - def __init__(self, element: SkyPixDimension, field: str) -> None: - self._record_converter = _SkypixDimensionRecordRowConverter(element) - self._field = field - - def convert(self, row: sqlalchemy.Row) -> Any: - record = self._record_converter.convert(row) - return getattr(record, self._field) diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py index adc4a94857..021d3df8bd 100644 --- a/python/lsst/daf/butler/queries/_general_query_results.py +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -35,7 +35,7 @@ from .._dataset_ref import DatasetRef from .._dataset_type import DatasetType -from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord +from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet from ._base import QueryResultsBase from .driver import QueryDriver from .result_specs import GeneralResultSpec @@ -99,9 +99,13 @@ def __iter__(self) -> Iterator[dict[str, Any]]: fields (separated from dataset type name by dot). """ for page in self._driver.execute(self._spec, self._tree): - columns = tuple(str(column) for column in page.spec.get_all_result_columns()) + columns = tuple(str(column) for column in page.spec.get_result_columns()) for row in page.rows: - yield dict(zip(columns, row, strict=True)) + result = dict(zip(columns, row, strict=True)) + if page.dimension_records: + records = self._get_cached_dimension_records(result, page.dimension_records) + self._add_dimension_records(result, records) + yield result def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]: """Iterate over result rows and return data coordinate, and dataset @@ -124,13 +128,21 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl id_key = f"{dataset_type.name}.dataset_id" run_key = f"{dataset_type.name}.run" dataset_keys.append((dataset_type, dimensions, id_key, run_key)) - for row in self: - data_coordinate = self._make_data_id(row, all_dimensions) - refs = [] - for dataset_type, dimensions, id_key, run_key in dataset_keys: - data_id = data_coordinate.subset(dimensions) - refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key])) - yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row) + for page in self._driver.execute(self._spec, self._tree): + columns = tuple(str(column) for column in page.spec.get_result_columns()) + for page_row in page.rows: + row = dict(zip(columns, page_row, strict=True)) + if page.dimension_records: + cached_records = self._get_cached_dimension_records(row, page.dimension_records) + self._add_dimension_records(row, cached_records) + else: + cached_records = {} + data_coordinate = self._make_data_id(row, all_dimensions, cached_records) + refs = [] + for dataset_type, dimensions, id_key, run_key in dataset_keys: + data_id = data_coordinate.subset(dimensions) + refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key])) + yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row) @property def dimensions(self) -> DimensionGroup: @@ -162,14 +174,22 @@ def _get_datasets(self) -> frozenset[str]: # Docstring inherited. return frozenset(self._spec.dataset_fields) - def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate: + def _make_data_id( + self, + row: dict[str, Any], + dimensions: DimensionGroup, + cached_row_records: dict[DimensionElement, DimensionRecord], + ) -> DataCoordinate: values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) data_coordinate = DataCoordinate.from_full_values(dimensions, values) if self.has_dimension_records: - records = { - name: self._make_dimension_record(row, dimensions.universe[name]) - for name in dimensions.elements - } + records = {} + for name in dimensions.elements: + element = dimensions.universe[name] + record = cached_row_records.get(element) + if record is None: + record = self._make_dimension_record(row, dimensions.universe[name]) + records[name] = record data_coordinate = data_coordinate.expanded(records) return data_coordinate @@ -185,3 +205,22 @@ def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) d = {k: row[v] for k, v in column_map} record_cls = element.RecordClass return record_cls(**d) + + def _get_cached_dimension_records( + self, row: dict[str, Any], dimension_records: dict[DimensionElement, DimensionRecordSet] + ) -> dict[DimensionElement, DimensionRecord]: + """Find cached dimension records matching this row.""" + records = {} + for element, element_records in dimension_records.items(): + required_values = tuple(row[key] for key in element.required.names) + records[element] = element_records.find_with_required_values(required_values) + return records + + def _add_dimension_records( + self, row: dict[str, Any], records: dict[DimensionElement, DimensionRecord] + ) -> None: + """Extend row with the fields from cached dimension records.""" + for element, record in records.items(): + for name, value in record.toDict().items(): + if name not in element.schema.required.names: + row[f"{element.name}.{name}"] = value diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py index d1e23ef813..22703c73c1 100644 --- a/python/lsst/daf/butler/queries/driver.py +++ b/python/lsst/daf/butler/queries/driver.py @@ -47,6 +47,7 @@ from ..dimensions import ( DataCoordinate, DataIdValue, + DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet, @@ -117,9 +118,13 @@ class GeneralResultPage: spec: GeneralResultSpec # Raw tabular data, with columns in the same order as - # spec.get_all_result_columns(). + # spec.get_result_columns(). rows: list[tuple[Any, ...]] + # This map contains dimension records for cached and skypix elements, + # and only when spec.include_dimension_records is True. + dimension_records: dict[DimensionElement, DimensionRecordSet] | None + ResultPage: TypeAlias = Union[ DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py index 6e3b1360e2..baf131d865 100644 --- a/python/lsst/daf/butler/queries/result_specs.py +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -248,33 +248,12 @@ def get_result_columns(self) -> ColumnSet: result.dataset_fields[dataset_type].update(fields_for_dataset) if self.include_dimension_records: # This only adds record fields for non-cached and non-skypix - # elements, this is what we want when generating query. We could - # potentially add those too but it may make queries slower, so - # instead we query cached dimension records separately and add them - # to the result page in the page converter. + # elements, this is what we want when generating query. When + # `include_dimension_records` is True, dimension records for cached + # and skypix elements are added to result pages by page converter. _add_dimension_records_to_column_set(self.dimensions, result) return result - def get_all_result_columns(self) -> ColumnSet: - """Return all columns that have to appear in the result. This includes - columns for all dimension records for all dimensions if - ``include_dimension_records`` is `True`. - - Returns - ------- - columns : `ColumnSet` - Full column set. - """ - dimensions = self.dimensions - result = self.get_result_columns() - if self.include_dimension_records: - for element_name in dimensions.elements: - element = dimensions.universe[element_name] - # Non-cached dimensions are already there, but it does not harm - # to add them again. - result.dimension_fields[element_name].update(element.schema.remainder.names) - return result - @pydantic.model_validator(mode="after") def _validate(self) -> GeneralResultSpec: if self.find_first and len(self.dataset_fields) != 1: diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 05949243d3..47cec7c248 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -40,7 +40,14 @@ from ...butler import Butler from .._dataset_type import DatasetType -from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionRecord, DimensionUniverse +from ..dimensions import ( + DataCoordinate, + DataIdValue, + DimensionGroup, + DimensionRecord, + DimensionRecordSet, + DimensionUniverse, +) from ..queries.driver import ( DataCoordinateResultPage, DatasetRefResultPage, @@ -257,20 +264,15 @@ def _convert_query_result_page( def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: """Convert GeneralResultModel to a general result page.""" - columns = spec.get_all_result_columns() - # Verify that column list that we received from server matches local - # expectations (mismatch could result from different versions). Older - # server may not know about `model.columns` in that case it will be empty. - # If `model.columns` is empty then `zip(strict=True)` below will fail if - # column count is different (column names are not checked in that case). - if model.columns: - expected_column_names = [str(column) for column in columns] - if expected_column_names != model.columns: + if spec.include_dimension_records: + # dimension_records must not be None when `include_dimension_records` + # is True, but it will be None if remote server was not upgraded. + if model.dimension_records is None: raise ValueError( - "Inconsistent columns in general result -- " - f"server columns: {model.columns}, expected: {expected_column_names}" + "Missing dimension records in general result -- " "it is likely that server needs an upgrade." ) + columns = spec.get_result_columns() serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] @@ -278,4 +280,15 @@ def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True)) for row in model.rows ] - return GeneralResultPage(spec=spec, rows=rows) + + universe = spec.dimensions.universe + dimension_records = None + if model.dimension_records is not None: + dimension_records = {} + for name, records in model.dimension_records.items(): + element = universe[name] + dimension_records[element] = DimensionRecordSet( + element, (DimensionRecord.from_simple(r, universe) for r in records) + ) + + return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index 7357699f19..d5a22b0d17 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -79,7 +79,7 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: """Convert GeneralResultPage to a serializable model.""" - columns = page.spec.get_all_result_columns() + columns = page.spec.get_result_columns() serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] @@ -87,4 +87,10 @@ def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True)) for row in page.rows ] - return GeneralResultModel(rows=rows, columns=[str(column) for column in columns]) + dimension_records = None + if page.dimension_records is not None: + dimension_records = { + element.name: [record.to_simple() for record in records] + for element, records in page.dimension_records.items() + } + return GeneralResultModel(rows=rows, dimension_records=dimension_records) diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index 6a92d4a4ad..4f76e605bc 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -313,9 +313,10 @@ class GeneralResultModel(pydantic.BaseModel): type: Literal["general"] = "general" rows: list[tuple[Any, ...]] - # List of column names, default is used for compatibility with older + # Dimension records indexed by element name, only cached and skypix + # elements are included. Default is used for compatibility with older # servers that do not set this field. - columns: list[str] = pydantic.Field(default_factory=list) + dimension_records: dict[str, list[SerializedDimensionRecord]] | None = None class QueryErrorResultModel(pydantic.BaseModel):