Skip to content

Commit

Permalink
Use more efficient dimension records representation.
Browse files Browse the repository at this point in the history
GeneralResultPage now stores cached dimension records in a separate structure
which avoids massive duplication of the records.
  • Loading branch information
andy-slac committed Dec 19, 2024
1 parent 8be90c3 commit 8796b6f
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -330,35 +330,44 @@ def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) ->
# columns returned by the query we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
query_result_columns = set(spec.get_result_columns())
output_columns = spec.get_all_result_columns()
columns = spec.get_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in output_columns:
self.record_converters: dict[DimensionElement, _DimensionRecordRowConverter] = {}
for column in columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
converter: _GeneralColumnConverter
if column not in query_result_columns and column.field is not None:
# This must be a field from a cached dimension record or
# skypix record.
assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here"
element = universe[column.logical_table]
if isinstance(element, SkyPixDimension):
converter = _SkypixRecordGeneralColumnConverter(element, column.field)
else:
converter = _CachedRecordGeneralColumnConverter(
element, column.field, ctx.dimension_record_cache
)
elif column.field == TimespanDatabaseRepresentation.NAME:
if column.field == TimespanDatabaseRepresentation.NAME:
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
converter = _TimestampGeneralColumnConverter(column_name)
else:
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

if spec.include_dimension_records:
universe = self.spec.dimensions.universe
for element_name in self.spec.dimensions.elements:
element = universe[element_name]
if isinstance(element, SkyPixDimension):
self.record_converters[element] = _SkypixDimensionRecordRowConverter(element)

Check warning on line 353 in python/lsst/daf/butler/direct_query_driver/_result_page_converter.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/direct_query_driver/_result_page_converter.py#L353

Added line #L353 was not covered by tests
elif element.is_cached:
self.record_converters[element] = _CachedDimensionRecordRowConverter(
element, ctx.dimension_record_cache
)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
return GeneralResultPage(spec=self.spec, rows=rows)
rows = []
dimension_records = None
if self.spec.include_dimension_records:
dimension_records = {element: DimensionRecordSet(element) for element in self.record_converters}
for row in raw_rows:
rows.append(tuple(cvt.convert(row) for cvt in self.converters))
if dimension_records:
for element, converter in self.record_converters.items():
dimension_records[element].add(converter.convert(row))

return GeneralResultPage(spec=self.spec, rows=rows, dimension_records=dimension_records)


class _GeneralColumnConverter:
Expand Down Expand Up @@ -440,47 +449,3 @@ def __init__(self, name: str, db: Database):
def convert(self, row: sqlalchemy.Row) -> Any:
timespan = self.timespan_class.extract(row._mapping, self.name)
return timespan


class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for cached
dimension records.
Parameters
----------
element : `DimensionElement`
Dimension element, must be of cached type.
field : `str`
Name of the field to extract from the dimension record.
cache : `DimensionRecordCache`
Cache for dimension records.
"""

def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None:
self._record_converter = _CachedDimensionRecordRowConverter(element, cache)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)


class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for skypix
dimension records.
Parameters
----------
element : `SkyPixDimension`
Dimension element.
field : `str`
Name of the field to extract from the dimension record.
"""

def __init__(self, element: SkyPixDimension, field: str) -> None:
self._record_converter = _SkypixDimensionRecordRowConverter(element)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)
69 changes: 54 additions & 15 deletions python/lsst/daf/butler/queries/_general_query_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
Expand Down Expand Up @@ -99,9 +99,13 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
fields (separated from dataset type name by dot).
"""
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_all_result_columns())
columns = tuple(str(column) for column in page.spec.get_result_columns())
for row in page.rows:
yield dict(zip(columns, row, strict=True))
result = dict(zip(columns, row, strict=True))
if page.dimension_records:
records = self._get_cached_dimension_records(result, page.dimension_records)
self._add_dimension_records(result, records)
yield result

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
Expand All @@ -124,13 +128,21 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl
id_key = f"{dataset_type.name}.dataset_id"
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
data_coordinate = self._make_data_id(row, all_dimensions)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
data_id = data_coordinate.subset(dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
for page_row in page.rows:
row = dict(zip(columns, page_row, strict=True))
if page.dimension_records:
cached_records = self._get_cached_dimension_records(row, page.dimension_records)
self._add_dimension_records(row, cached_records)
else:
cached_records = {}
data_coordinate = self._make_data_id(row, all_dimensions, cached_records)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
data_id = data_coordinate.subset(dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)

@property
def dimensions(self) -> DimensionGroup:
Expand Down Expand Up @@ -162,14 +174,22 @@ def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate:
def _make_data_id(
self,
row: dict[str, Any],
dimensions: DimensionGroup,
cached_row_records: dict[DimensionElement, DimensionRecord],
) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {
name: self._make_dimension_record(row, dimensions.universe[name])
for name in dimensions.elements
}
records = {}
for name in dimensions.elements:
element = dimensions.universe[name]
record = cached_row_records.get(element)
if record is None:
record = self._make_dimension_record(row, dimensions.universe[name])

Check warning on line 191 in python/lsst/daf/butler/queries/_general_query_results.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/_general_query_results.py#L191

Added line #L191 was not covered by tests
records[name] = record
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

Expand All @@ -185,3 +205,22 @@ def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement)
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)

Check warning on line 207 in python/lsst/daf/butler/queries/_general_query_results.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/queries/_general_query_results.py#L204-L207

Added lines #L204 - L207 were not covered by tests

def _get_cached_dimension_records(
self, row: dict[str, Any], dimension_records: dict[DimensionElement, DimensionRecordSet]
) -> dict[DimensionElement, DimensionRecord]:
"""Find cached dimension records matching this row."""
records = {}
for element, element_records in dimension_records.items():
required_values = tuple(row[key] for key in element.required.names)
records[element] = element_records.find_with_required_values(required_values)
return records

def _add_dimension_records(
self, row: dict[str, Any], records: dict[DimensionElement, DimensionRecord]
) -> None:
"""Extend row with the fields from cached dimension records."""
for element, record in records.items():
for name, value in record.toDict().items():
if name not in element.schema.required.names:
row[f"{element.name}.{name}"] = value
7 changes: 6 additions & 1 deletion python/lsst/daf/butler/queries/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionElement,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
Expand Down Expand Up @@ -117,9 +118,13 @@ class GeneralResultPage:
spec: GeneralResultSpec

# Raw tabular data, with columns in the same order as
# spec.get_all_result_columns().
# spec.get_result_columns().
rows: list[tuple[Any, ...]]

# This map contains dimension records for cached and skypix elements,
# and only when spec.include_dimension_records is True.
dimension_records: dict[DimensionElement, DimensionRecordSet] | None


ResultPage: TypeAlias = Union[
DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage
Expand Down
27 changes: 3 additions & 24 deletions python/lsst/daf/butler/queries/result_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,33 +248,12 @@ def get_result_columns(self) -> ColumnSet:
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, this is what we want when generating query. We could
# potentially add those too but it may make queries slower, so
# instead we query cached dimension records separately and add them
# to the result page in the page converter.
# elements, this is what we want when generating query. When
# `include_dimension_records` is True, dimension records for cached
# and skypix elements are added to result pages by page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

def get_all_result_columns(self) -> ColumnSet:
"""Return all columns that have to appear in the result. This includes
columns for all dimension records for all dimensions if
``include_dimension_records`` is `True`.
Returns
-------
columns : `ColumnSet`
Full column set.
"""
dimensions = self.dimensions
result = self.get_result_columns()
if self.include_dimension_records:
for element_name in dimensions.elements:
element = dimensions.universe[element_name]
# Non-cached dimensions are already there, but it does not harm
# to add them again.
result.dimension_fields[element_name].update(element.schema.remainder.names)
return result

@pydantic.model_validator(mode="after")
def _validate(self) -> GeneralResultSpec:
if self.find_first and len(self.dataset_fields) != 1:
Expand Down
39 changes: 26 additions & 13 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@

from ...butler import Butler
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionRecord, DimensionUniverse
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
DimensionUniverse,
)
from ..queries.driver import (
DataCoordinateResultPage,
DatasetRefResultPage,
Expand Down Expand Up @@ -257,25 +264,31 @@ def _convert_query_result_page(

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
columns = spec.get_all_result_columns()
# Verify that column list that we received from server matches local
# expectations (mismatch could result from different versions). Older
# server may not know about `model.columns` in that case it will be empty.
# If `model.columns` is empty then `zip(strict=True)` below will fail if
# column count is different (column names are not checked in that case).
if model.columns:
expected_column_names = [str(column) for column in columns]
if expected_column_names != model.columns:
if spec.include_dimension_records:
# dimension_records must not be None when `include_dimension_records`
# is True, but it will be None if remote server was not upgraded.
if model.dimension_records is None:
raise ValueError(

Check warning on line 271 in python/lsst/daf/butler/remote_butler/_query_driver.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/daf/butler/remote_butler/_query_driver.py#L271

Added line #L271 was not covered by tests
"Inconsistent columns in general result -- "
f"server columns: {model.columns}, expected: {expected_column_names}"
"Missing dimension records in general result -- " "it is likely that server needs an upgrade."
)

columns = spec.get_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)

universe = spec.dimensions.universe
dimension_records = None
if model.dimension_records is not None:
dimension_records = {}
for name, records in model.dimension_records.items():
element = universe[name]
dimension_records[element] = DimensionRecordSet(
element, (DimensionRecord.from_simple(r, universe) for r in records)
)

return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records)
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,18 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult

def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
"""Convert GeneralResultPage to a serializable model."""
columns = page.spec.get_all_result_columns()
columns = page.spec.get_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows, columns=[str(column) for column in columns])
dimension_records = None
if page.dimension_records is not None:
dimension_records = {
element.name: [record.to_simple() for record in records]
for element, records in page.dimension_records.items()
}
return GeneralResultModel(rows=rows, dimension_records=dimension_records)
5 changes: 3 additions & 2 deletions python/lsst/daf/butler/remote_butler/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,10 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# List of column names, default is used for compatibility with older
# Dimension records indexed by element name, only cached and skypix
# elements are included. Default is used for compatibility with older
# servers that do not set this field.
columns: list[str] = pydantic.Field(default_factory=list)
dimension_records: dict[str, list[SerializedDimensionRecord]] | None = None


class QueryErrorResultModel(pydantic.BaseModel):
Expand Down

0 comments on commit 8796b6f

Please sign in to comment.