Skip to content

Commit

Permalink
Add an option to include dimension records into general query result …
Browse files Browse the repository at this point in the history
…(DM-47980)

The `GeneralQueryResults.iter_tuples` method returned DataIds without
dimension records. In some cases (e.g. for obscore export) it would be
very useful to include records in the same result to avoid querying
them separately. New method `with_dimension_records` is added to the class
to trigger adding fields from all dimension records into returned page.
This will produce many duplicates for some dimensions (e.g. `instrument`)
but it keeps page structure simple.

This adds one attribute to the `GeneralResultSpec` class, will need some
care with Butler server compatibility.
  • Loading branch information
andy-slac committed Dec 16, 2024
1 parent 812ef82 commit 4bbe699
Showing 8 changed files with 209 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -326,17 +326,35 @@ class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01

def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
self.spec = spec

result_columns = spec.get_result_columns()
# In case `spec.include_dimension_records` is True then in addition to
# columns returned by the query we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
query_result_columns = set(spec.get_result_columns())
output_columns = spec.get_all_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in result_columns:
for column in output_columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
if column.field == TimespanDatabaseRepresentation.NAME:
self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
converter: _GeneralColumnConverter
if column not in query_result_columns and column.field is not None:
# This must be a field from a cached dimension record or
# skypix record.
assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here"
element = universe[column.logical_table]
if isinstance(element, SkyPixDimension):
converter = _SkypixRecordGeneralColumnConverter(element, column.field)

Check warning on line 346 in python/lsst/daf/butler/direct_query_driver/_result_page_converter.py

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_result_page_converter.py#L346

Added line #L346 was not covered by tests
else:
converter = _CachedRecordGeneralColumnConverter(
element, column.field, ctx.dimension_record_cache
)
elif column.field == TimespanDatabaseRepresentation.NAME:
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
self.converters.append(_TimestampGeneralColumnConverter(column_name))
converter = _TimestampGeneralColumnConverter(column_name)
else:
self.converters.append(_DefaultGeneralColumnConverter(column_name))
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
@@ -422,3 +440,47 @@ def __init__(self, name: str, db: Database):
def convert(self, row: sqlalchemy.Row) -> Any:
timespan = self.timespan_class.extract(row._mapping, self.name)
return timespan


class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for cached
dimension records.
Parameters
----------
element : `DimensionElement`
Dimension element, must be of cached type.
field : `str`
Name of the field to extract from the dimension record.
cache : `DimensionRecordCache`
Cache for dimension records.
"""

def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None:
self._record_converter = _CachedDimensionRecordRowConverter(element, cache)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)


class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for skypix
dimension records.
Parameters
----------
element : `SkyPixDimension`
Dimension element.
field : `str`
Name of the field to extract from the dimension record.
"""

def __init__(self, element: SkyPixDimension, field: str) -> None:
self._record_converter = _SkypixDimensionRecordRowConverter(element)
self._field = field

Check warning on line 482 in python/lsst/daf/butler/direct_query_driver/_result_page_converter.py

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_result_page_converter.py#L481-L482

Added lines #L481 - L482 were not covered by tests

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)

Check warning on line 486 in python/lsst/daf/butler/direct_query_driver/_result_page_converter.py

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_result_page_converter.py#L485-L486

Added lines #L485 - L486 were not covered by tests
53 changes: 43 additions & 10 deletions python/lsst/daf/butler/queries/_general_query_results.py
Original file line number Diff line number Diff line change
@@ -35,11 +35,11 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionGroup
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import QueryTree
from .tree import QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
@@ -99,9 +99,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
fields (separated from dataset type name by dot).
"""
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
columns = tuple(str(column) for column in page.spec.get_all_result_columns())
for row in page.rows:
yield dict(zip(columns, row))
yield dict(zip(columns, row, strict=True))

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
@@ -125,14 +125,10 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
values = tuple(
row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied)
)
data_coordinate = DataCoordinate.from_full_values(all_dimensions, values)
data_coordinate = self._make_data_id(row, all_dimensions)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_id = DataCoordinate.from_full_values(dimensions, values)
data_id = self._make_data_id(row, dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)

@@ -141,6 +137,19 @@ def dimensions(self) -> DimensionGroup:
# Docstring inherited
return self._spec.dimensions

@property
def has_dimension_records(self) -> bool:
"""Whether all data IDs in this iterable contain dimension records."""
return self._spec.include_dimension_records

def with_dimension_records(self) -> GeneralQueryResults:
"""Return a results object for which `has_dimension_records` is
`True`.
"""
if self.has_dimension_records:
return self

Check warning on line 150 in python/lsst/daf/butler/queries/_general_query_results.py

Codecov / codecov/patch

python/lsst/daf/butler/queries/_general_query_results.py#L150

Added line #L150 was not covered by tests
return self._copy(tree=self._tree, include_dimension_records=True)

def count(self, *, exact: bool = True, discard: bool = False) -> int:
# Docstring inherited.
return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)
@@ -152,3 +161,27 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults:
def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {
name: self._make_dimension_record(row, dimensions.universe[name])
for name in dimensions.elements
}
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
column_map = list(
zip(
element.schema.dimensions.names,
element.dimensions.names,
)
)
for field in element.schema.remainder.names:
column_map.append((field, str(ResultColumn(element.name, field))))
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/queries/driver.py
Original file line number Diff line number Diff line change
@@ -117,7 +117,7 @@ class GeneralResultPage:
spec: GeneralResultSpec

# Raw tabular data, with columns in the same order as
# spec.get_result_columns().
# spec.get_all_result_columns().
rows: list[tuple[Any, ...]]


32 changes: 32 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
Original file line number Diff line number Diff line change
@@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase):
dataset_fields: Mapping[str, set[DatasetFieldName]]
"""Dataset fields included in this query."""

include_dimension_records: bool = False
"""Whether to include fields for all dimension records, in addition to
explicitly specified in `dimension_fields`.
"""

find_first: bool
"""Whether this query requires find-first resolution for a dataset.
@@ -241,6 +246,33 @@ def get_result_columns(self) -> ColumnSet:
result.dimension_fields[element_name].update(fields_for_element)
for dataset_type, fields_for_dataset in self.dataset_fields.items():
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, this is what we want when generating query. We could
# potentially add those too but it may make queries slower, so
# instead we query cached dimension records separately and add them
# to the result page in the page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

def get_all_result_columns(self) -> ColumnSet:
"""Return all columns that have to appear in the result. This includes
columns for all dimension records for all dimensions if
``include_dimension_records`` is `True`.
Returns
-------
columns : `ColumnSet`
Full column set.
"""
dimensions = self.dimensions
result = self.get_result_columns()
if self.include_dimension_records:
for element_name in dimensions.elements:
element = dimensions.universe[element_name]
# Non-cached dimensions are already there, but it does not harm
# to add them again.
result.dimension_fields[element_name].update(element.schema.remainder.names)
return result

@pydantic.model_validator(mode="after")
17 changes: 15 additions & 2 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
Original file line number Diff line number Diff line change
@@ -257,12 +257,25 @@ def _convert_query_result_page(

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
columns = spec.get_result_columns()
columns = spec.get_all_result_columns()
# Verify that column list that we received from server matches local
# expectations (mismatch could result from different versions). Older
# server may not know about `model.columns` in that case it will be empty.
# If `model.columns` is empty then `zip(strict=True)` below will fail if
# column count is different (column names are not checked in that case).
if model.columns:
expected_column_names = [str(column) for column in columns]
if expected_column_names != model.columns:
raise ValueError(

Check warning on line 269 in python/lsst/daf/butler/remote_butler/_query_driver.py

Codecov / codecov/patch

python/lsst/daf/butler/remote_butler/_query_driver.py#L269

Added line #L269 was not covered by tests
"Inconsistent columns in general result -- "
f"server columns: {model.columns}, expected: {expected_column_names}"
)

serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers))
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)
Original file line number Diff line number Diff line change
@@ -79,11 +79,12 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult

def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
"""Convert GeneralResultPage to a serializable model."""
columns = page.spec.get_result_columns()
columns = page.spec.get_all_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows)
return GeneralResultModel(rows=rows, columns=[str(column) for column in columns])
3 changes: 3 additions & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
Original file line number Diff line number Diff line change
@@ -313,6 +313,9 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# List of column names, default is used for compatibility with older
# servers that do not set this field.
columns: list[str] = pydantic.Field(default_factory=list)


class QueryErrorResultModel(pydantic.BaseModel):
42 changes: 42 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
Original file line number Diff line number Diff line change
@@ -436,7 +436,9 @@ def test_general_query(self) -> None:
self.assertEqual(len(row_tuple.refs), 1)
self.assertEqual(row_tuple.refs[0].datasetType, flat)
self.assertTrue(row_tuple.refs[0].dataId.hasFull())
self.assertFalse(row_tuple.refs[0].dataId.hasRecords())
self.assertTrue(row_tuple.data_id.hasFull())
self.assertFalse(row_tuple.data_id.hasRecords())
self.assertEqual(row_tuple.data_id.dimensions, dimensions)
self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g")

@@ -511,6 +513,46 @@ def test_general_query(self) -> None:
{Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None},
)

dimensions = butler.dimensions["detector"].minimal_group

# Include dimension records into query.
with butler.query() as query:
query = query.join_dimensions(dimensions)
result = query.general(dimensions).order_by("detector")
rows = list(result.with_dimension_records())
self.assertEqual(
rows[0],
{
"instrument": "Cam1",
"detector": 1,
"instrument.visit_max": 1024,
"instrument.visit_system": 1,
"instrument.exposure_max": 512,
"instrument.detector_max": 4,
"instrument.class_name": "lsst.pipe.base.Instrument",
"detector.full_name": "Aa",
"detector.name_in_raft": "a",
"detector.raft": "A",
"detector.purpose": "SCIENCE",
},
)

dimensions = butler.dimensions.conform(["detector", "physical_filter"])

# DataIds should come with records.
with butler.query() as query:
query = query.join_dataset_search("flat", "imported_g")
result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by(
"detector"
)
result = result.with_dimension_records()
row_tuples = list(result.iter_tuples(flat))
self.assertEqual(len(row_tuples), 3)
for row_tuple in row_tuples:
self.assertTrue(row_tuple.data_id.hasRecords())
self.assertEqual(len(row_tuple.refs), 1)
self.assertTrue(row_tuple.refs[0].dataId.hasRecords())

def test_query_ingest_date(self) -> None:
"""Test general query returning ingest_date field."""
before_ingest = astropy.time.Time.now()

0 comments on commit 4bbe699

Please sign in to comment.