DM-46248: Few small improvements for the new query system #1146

Merged 5 commits on Jan 31, 2025
47 changes: 35 additions & 12 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -217,6 +217,7 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul
final_columns=result_spec.get_result_columns(),
order_by=result_spec.order_by,
find_first_dataset=result_spec.find_first_dataset,
allow_duplicate_overlaps=result_spec.allow_duplicate_overlaps,
)
sql_select, sql_columns = builder.finish_select()
if result_spec.order_by:
@@ -290,12 +291,15 @@ def materialize(
tree: qt.QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
key: qt.MaterializationKey | None = None,
) -> qt.MaterializationKey:
# Docstring inherited.
if self._exit_stack is None:
raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.")
plan = self.build_query(tree, qt.ColumnSet(dimensions))
plan = self.build_query(
tree, qt.ColumnSet(dimensions), allow_duplicate_overlaps=allow_duplicate_overlaps
)
# Current implementation ignores 'datasets' aside from remembering
# them, because figuring out what to put in the temporary table for
# them is tricky, especially if calibration collections are involved.
@@ -311,7 +315,9 @@ def materialize(
#
sql_select, _ = plan.finish_select(return_columns=False)
table = self._exit_stack.enter_context(
self.db.temporary_table(make_table_spec(plan.final_columns, self.db, plan.postprocessing))
self.db.temporary_table(
make_table_spec(plan.final_columns, self.db, plan.postprocessing, make_indices=True)
)

Member:
I recall you saying earlier that adding the indexes before doing the inserts sometimes made the inserts much more expensive. Did that get resolved by other changes, or is it just an overall win even if it's occasionally worse?

And would it be worth it to try to move index creation to just after insert in this method?

Contributor Author:
That was a false alarm - I added indices and filling the temp table became terribly slow, but the real reason was that the query planner went into a dumb state on that same query (this was before I dropped DISTINCT). With DISTINCT disabled I have not noticed any significant slowdown due to indices.
It is true that for bulk inserts it is better to load the data before adding indices. I tried to implement that, but it is not trivial with our code as it is now; the simple idea did not work, and I did not want to start restructuring a bunch of code just for that. Still, if we see a need later, it should be possible to delay index creation until after the data is loaded.

Contributor Author:
Just in case we want to do index creation separately later - my initial idea was to run the same _convertTableSpec but not add the Index instances to the Table constructor arguments; instead the indices would be returned as a separate list so you could later iterate over them and create them manually. It turns out that even if you do not pass an Index to the Table() arguments, the metadata already knows about all indices for that table (as an Index is created for a specific table) and will create that index anyway. So simply delaying did not work.
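For reference, a minimal standalone SQLAlchemy sketch of the behaviour described above (not butler code; the table and column names are made up): an Index built against a table's columns attaches itself to that table, so Table.create() emits it even though it was never passed to the Table() constructor.

```python
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
demo = sqlalchemy.Table(
    "demo", metadata, sqlalchemy.Column("dimension_id", sqlalchemy.Integer)
)

# The Index is constructed separately and never passed to Table(...), but
# because it references demo.c.dimension_id it registers itself on the table.
index = sqlalchemy.Index("ix_demo_dimension_id", demo.c.dimension_id)
assert index in demo.indexes

# Creating the table therefore emits CREATE INDEX as well, which is why simply
# leaving the Index out of the Table() arguments does not defer its creation.
demo.create(engine)
```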

)
self.db.insert(table, select=sql_select)
if key is None:
@@ -401,7 +407,7 @@ def count(

def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
# Docstring inherited.
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions))
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions), allow_duplicate_overlaps=True)
if not all(d.collection_records for d in builder.joins_analysis.datasets.values()):
return False
if not execute:
@@ -447,6 +453,7 @@ def build_query(
order_by: Iterable[qt.OrderExpression] = (),
find_first_dataset: str | qt.AnyDatasetType | None = None,
analyze_only: bool = False,
allow_duplicate_overlaps: bool = False,
) -> QueryBuilder:
"""Convert a query description into a nearly-complete builder object
for the SQL version of that query.
@@ -470,6 +477,9 @@
builder, but do not call methods that build its SQL form. This can
be useful for obtaining diagnostic information about the query that
would be generated.
allow_duplicate_overlaps : `bool`, optional
If set to `True`, the query will be allowed to generate
non-distinct rows for spatial overlaps.
Returns
-------
@@ -542,7 +552,7 @@
# SqlSelectBuilder and Postprocessing with spatial/temporal constraints
# potentially transformed by the dimensions manager (but none of the
# rest of the analysis reflected in that SqlSelectBuilder).
query_tree_analysis = self._analyze_query_tree(tree)
query_tree_analysis = self._analyze_query_tree(tree, allow_duplicate_overlaps)
# The "projection" columns differ from the final columns by not
# omitting any dimension keys (this keeps queries for different result
# types more similar during construction), including any columns needed
@@ -589,7 +599,7 @@
builder.apply_find_first(self)
return builder

def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
def _analyze_query_tree(self, tree: qt.QueryTree, allow_duplicate_overlaps: bool) -> QueryTreeAnalysis:
"""Analyze a `.queries.tree.QueryTree` as the first step in building
a SQL query.
@@ -603,6 +613,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree_analysis : `QueryTreeAnalysis`
Struct containing additional information needed to build the joins
stage of a query.
allow_duplicate_overlaps : `bool`, optional
If set to `True`, the query will be allowed to generate
non-distinct rows for spatial overlaps.
Notes
-----
@@ -632,6 +645,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree.predicate,
tree.get_joined_dimension_groups(),
collection_analysis.calibration_dataset_types,
allow_duplicate_overlaps,
)
# Extract the data ID implied by the predicate; we can use the governor
# dimensions in that to constrain the collections we search for
@@ -799,13 +813,22 @@ def apply_initial_query_joins(
select_builder.joins, materialization_key, materialization_dimensions
)
)
# Process dataset joins (not including any union dataset).
for dataset_search in joins_analysis.datasets.values():
self.join_dataset_search(
select_builder.joins,
dataset_search,
joins_analysis.columns.dataset_fields[dataset_search.name],
)
# Process dataset joins (not including any union dataset). Dataset
# searches included in materialization can be skipped unless we need
# something from their tables.
materialized_datasets = set()
for m_state in self._materializations.values():
materialized_datasets.update(m_state.datasets)
for dataset_type_name, dataset_search in joins_analysis.datasets.items():
if (
dataset_type_name not in materialized_datasets
or dataset_type_name in select_builder.columns.dataset_fields
):
self.join_dataset_search(
select_builder.joins,
dataset_search,
joins_analysis.columns.dataset_fields[dataset_search.name],
)
# Join in dimension element tables that we know we need relationships
# or columns from.
for element in joins_analysis.iter_mandatory(union_dataset_dimensions):
26 changes: 24 additions & 2 deletions python/lsst/daf/butler/direct_query_driver/_sql_builders.py
@@ -37,6 +37,8 @@
import sqlalchemy

from .. import ddl
from ..dimensions import DimensionGroup
from ..dimensions._group import SortedSequenceSet
from ..nonempty_mapping import NonemptyMapping
from ..queries import tree as qt
from ._postprocessing import Postprocessing
@@ -638,7 +640,7 @@ def to_select_builder(


def make_table_spec(
columns: qt.ColumnSet, db: Database, postprocessing: Postprocessing | None
columns: qt.ColumnSet, db: Database, postprocessing: Postprocessing | None, *, make_indices: bool = False
) -> ddl.TableSpec:
"""Make a specification that can be used to create a table to store
this query's outputs.
@@ -652,18 +654,22 @@
postprocessing : `Postprocessing`
Struct representing post-query processing in Python, which may
require additional columns in the query results.
make_indices : `bool`, optional
If `True`, add indices for groups of columns.

Returns
-------
spec : `.ddl.TableSpec`
Table specification for this query's result columns (including
those from `postprocessing` and `SqlJoinsBuilder.special`).
"""
indices = _make_table_indices(columns.dimensions) if make_indices else []
results = ddl.TableSpec(
[
columns.get_column_spec(logical_table, field).to_sql_spec(name_shrinker=db.name_shrinker)
for logical_table, field in columns
]
],
indexes=indices,
)
if postprocessing:
for element in postprocessing.iter_missing(columns):
Expand All @@ -679,3 +685,19 @@ def make_table_spec(
ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
)
return results


def _make_table_indices(dimensions: DimensionGroup) -> list[ddl.IndexSpec]:
# Build one index per "maximal" group of required dimension keys: each
# required dimension contributes its minimal group, and a previously
# collected group that is a subset of the new one is widened in place
# rather than kept as a separate index.
index_columns: list[SortedSequenceSet] = []
for dimension in dimensions.required:
minimal_group = dimensions.universe[dimension].minimal_group.required

for idx in range(len(index_columns)):
if index_columns[idx] <= minimal_group:
index_columns[idx] = minimal_group
break
else:
index_columns.append(minimal_group)

return [ddl.IndexSpec(*columns) for columns in index_columns]
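As a rough illustration of the grouping logic in _make_table_indices above, here is a standalone sketch using plain Python sets and hypothetical dimension groups (not the real butler dimension universe): each dimension's minimal group either widens an existing candidate index that is a subset of it or starts a new one.

```python
# Hypothetical minimal groups contributed by required dimensions.
minimal_groups = [
    {"instrument"},
    {"instrument", "visit"},
    {"skymap"},
    {"skymap", "tract", "patch"},
]

index_columns: list[set[str]] = []
for group in minimal_groups:
    for i, existing in enumerate(index_columns):
        if existing <= group:  # widen a candidate that is a subset of this group
            index_columns[i] = group
            break
    else:
        index_columns.append(group)  # otherwise this group gets its own index

# Two indices remain: {"instrument", "visit"} and {"skymap", "tract", "patch"}.
print(index_columns)
```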
24 changes: 20 additions & 4 deletions python/lsst/daf/butler/queries/_query.py
@@ -106,6 +106,12 @@ def __init__(self, driver: QueryDriver, tree: QueryTree | None = None):
tree = make_identity_query_tree(driver.universe)
super().__init__(driver, tree)

# If ``_allow_duplicate_overlaps`` is set to `True`, the query will be
# allowed to generate non-distinct rows for spatial overlaps. This is
# not part of the public API for now; it is meant to be used by the
# graph builder as an optimization.
self._allow_duplicate_overlaps: bool = False

@property
def constraint_dataset_types(self) -> Set[str]:
"""The names of all dataset types joined into the query.
@@ -218,7 +224,11 @@ def data_ids(
dimensions = self._driver.universe.conform(dimensions)
if not dimensions <= self._tree.dimensions:
tree = tree.join_dimensions(dimensions)
result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False)
result_spec = DataCoordinateResultSpec(
dimensions=dimensions,
include_dimension_records=False,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DataCoordinateQueryResults(self._driver, tree, result_spec)

def datasets(
@@ -284,6 +294,7 @@ def datasets(
storage_class_name=storage_class_name,
include_dimension_records=False,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DatasetRefQueryResults(self._driver, tree=query._tree, spec=spec)

@@ -308,7 +319,9 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults:
tree = self._tree
if element not in tree.dimensions.elements:
tree = tree.join_dimensions(self._driver.universe[element].minimal_group)
result_spec = DimensionRecordResultSpec(element=self._driver.universe[element])
result_spec = DimensionRecordResultSpec(
element=self._driver.universe[element], allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
return DimensionRecordQueryResults(self._driver, tree, result_spec)

def general(
@@ -445,6 +458,7 @@ def general(
dimension_fields=dimension_fields_dict,
dataset_fields=dataset_fields_dict,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return GeneralQueryResults(self._driver, tree=tree, spec=result_spec)

@@ -495,7 +509,9 @@ def materialize(
dimensions = self._tree.dimensions
else:
dimensions = self._driver.universe.conform(dimensions)
key = self._driver.materialize(self._tree, dimensions, datasets)
key = self._driver.materialize(
self._tree, dimensions, datasets, allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
tree = make_identity_query_tree(self._driver.universe).join_materialization(
key, dimensions=dimensions
)
@@ -508,7 +524,7 @@
"Expand the dimensions or drop this dataset type in the arguments to materialize to "
"avoid this error."
)
tree = tree.join_dataset(dataset_type_name, self._tree.datasets[dataset_type_name])
tree = tree.join_dataset(dataset_type_name, dataset_search)
return Query(self._driver, tree)

def join_dataset_search(
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/queries/driver.py
@@ -209,6 +209,7 @@ def materialize(
tree: QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
) -> MaterializationKey:
"""Execute a query tree, saving results to temporary storage for use
in later queries.
@@ -222,6 +223,9 @@
datasets : `frozenset` [ `str` ]
Names of dataset types whose ID columns may be materialized. It
is implementation-defined whether they actually are.
allow_duplicate_overlaps : `bool`, optional
If set to `True`, the query will be allowed to generate
non-distinct rows for spatial overlaps.
Returns
-------
5 changes: 5 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -62,6 +62,11 @@ class ResultSpecBase(pydantic.BaseModel, ABC):
limit: int | None = None
"""Maximum number of rows to return, or `None` for no bound."""

allow_duplicate_overlaps: bool = False
"""If set to True the queries are allowed to returnd duplicate rows for
spatial overlaps.
"""

def validate_tree(self, tree: QueryTree) -> None:
"""Check that this result object is consistent with a query tree.

21 changes: 11 additions & 10 deletions python/lsst/daf/butler/registry/databases/postgresql.py
@@ -246,20 +246,21 @@ def _transaction(
# PostgreSQL actually considers SET TRANSACTION to be a
# fundamentally different statement from SET (they have their
# own distinct doc pages, at least).
if not (self.isWriteable() or for_temp_tables):
with closing(connection.connection.cursor()) as cursor:
# PostgreSQL permits writing to temporary tables inside
# read-only transactions, but it doesn't permit creating
# them.
with closing(connection.connection.cursor()) as cursor:
if not (self.isWriteable() or for_temp_tables):
cursor.execute("SET TRANSACTION READ ONLY")
cursor.execute("SET TIME ZONE 0")
else:
with closing(connection.connection.cursor()) as cursor:
# Make timestamps UTC, because we didn't use TIMESTAMPTZ
# for the column type. When we can tolerate a schema
# change, we should change that type and remove this
# line.
cursor.execute("SET TIME ZONE 0")
# Make timestamps UTC, because we didn't use TIMESTAMPTZ
# for the column type. When we can tolerate a schema
# change, we should change that type and remove this
# line.
cursor.execute("SET TIME ZONE 0")
# Using server-side cursors with complex queries frequently
# generates suboptimal query plans; setting
# cursor_tuple_fraction=1 helps in those cases.
cursor.execute("SET cursor_tuple_fraction = 1")
yield is_new, connection

@contextmanager
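As a rough sketch of what the new session setting does (against a hypothetical psycopg2 connection, not the butler Database API): PostgreSQL's cursor_tuple_fraction defaults to 0.1, so with server-side cursors the planner optimizes for fetching only a small fraction of the rows, which can produce a poor plan for complex joins; setting it to 1 makes the planner optimize for retrieving all rows, which is how the query system actually consumes results.

```python
import psycopg2

# Connection details are hypothetical.
conn = psycopg2.connect("dbname=butler")
with conn.cursor() as cur:
    cur.execute("SHOW cursor_tuple_fraction")
    print(cur.fetchone())  # ('0.1',) by default
    # Plan for retrieving every row of the cursor rather than the first ~10%.
    cur.execute("SET cursor_tuple_fraction = 1")
```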
@@ -1635,7 +1635,7 @@ def _get_calibs_table(self, table: DynamicTables) -> sqlalchemy.Table:

def _create_case_expression_for_collections(
collections: Iterable[CollectionRecord], id_column: sqlalchemy.ColumnElement
) -> sqlalchemy.Case | sqlalchemy.Null:
) -> sqlalchemy.ColumnElement:
"""Return a SQLAlchemy Case expression that converts collection IDs to
collection names for the given set of collections.
@@ -1661,6 +1661,6 @@
# cases, e.g. we start with a list of valid collections but they are
# all filtered out by higher-level code on the basis of collection
# summaries.
return sqlalchemy.null()
return sqlalchemy.cast(sqlalchemy.null(), sqlalchemy.String)

return sqlalchemy.case(mapping, value=id_column)
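A minimal standalone sketch of the pattern above (plain SQLAlchemy; the table and collection names are hypothetical): case() with a mapping yields a String-typed expression, while a bare null() has no SQL type, so casting it to String keeps the empty-mapping branch type-consistent with the normal one.

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
dataset = sqlalchemy.Table(
    "dataset", metadata, sqlalchemy.Column("collection_id", sqlalchemy.Integer)
)

mapping = {1: "run/a", 2: "run/b"}  # hypothetical collection ID -> name mapping
if mapping:
    expr = sqlalchemy.case(mapping, value=dataset.c.collection_id)
else:
    # An untyped NULL; the cast pins the column type to String so callers see
    # the same type whether or not any collections survived filtering.
    expr = sqlalchemy.cast(sqlalchemy.null(), sqlalchemy.String)

print(sqlalchemy.select(expr.label("collection_name")))
```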
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/registry/dimensions/static.py
@@ -487,9 +487,10 @@ def process_query_overlaps(
predicate: qt.Predicate,
join_operands: Iterable[DimensionGroup],
calibration_dataset_types: Set[str | qt.AnyDatasetType],
allow_duplicates: bool = False,
) -> tuple[qt.Predicate, SqlSelectBuilder, Postprocessing]:
overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(
self._db, dimensions, calibration_dataset_types, self._overlap_tables
self._db, dimensions, calibration_dataset_types, self._overlap_tables, allow_duplicates
)
new_predicate = overlaps_visitor.run(predicate, join_operands)
return new_predicate, overlaps_visitor.builder, overlaps_visitor.postprocessing
@@ -1025,13 +1026,15 @@ def __init__(
dimensions: DimensionGroup,
calibration_dataset_types: Set[str | qt.AnyDatasetType],
overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]],
allow_duplicates: bool,
):
super().__init__(dimensions, calibration_dataset_types)
self.builder: SqlSelectBuilder = SqlJoinsBuilder(db=db).to_select_builder(qt.ColumnSet(dimensions))
self.postprocessing = Postprocessing()
self.common_skypix = dimensions.universe.commonSkyPix
self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables
self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set()
self.allow_duplicates = allow_duplicates

def visit_spatial_constraint(
self,
@@ -1081,7 +1084,8 @@ def visit_spatial_constraint(
joins_builder.where(sqlalchemy.or_(*sql_where_or))
self.builder.join(
joins_builder.to_select_builder(
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(),
distinct=not self.allow_duplicates,
).into_joins_builder(postprocessing=None)
)
# Short circuit here since the SQL WHERE clause has already
@@ -1145,7 +1149,7 @@ def visit_spatial_join(
.join(self._make_common_skypix_overlap_joins_builder(b))
.to_select_builder(
qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
distinct=True,
distinct=not self.allow_duplicates,
)
.into_joins_builder(postprocessing=None)
)