Skip to content

Commit

Permalink
add capabilities for a full subquery filter (#247)
Browse files Browse the repository at this point in the history
* add capabilities for a full subquery filter

* fix query type issue

* additional fixes for serialization and mapping-only use cases

* fix issue with mapping in sub field

* fix recursive CTE SQL error on nested subquery filter

* bump version to 0.12.48
  • Loading branch information
pblankley authored Dec 19, 2024
1 parent 6ed8634 commit e3c5d85
Show file tree
Hide file tree
Showing 8 changed files with 761 additions and 41 deletions.
6 changes: 6 additions & 0 deletions metrics_layer/core/model/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,12 @@ class MetricsLayerFilterExpressionType(str, Enum):
IsNotNull = "is_not_null"
IsIn = "isin"
IsNotIn = "isnotin"
IsInQuery = "is_in_query"
IsNotInQuery = "is_not_in_query"
BooleanTrue = "boolean_true"
BooleanFalse = "boolean_false"
IsTrue = "is_true"
IsFalse = "is_false"
Matches = "matches"

def __hash__(self):
Expand Down Expand Up @@ -575,6 +579,8 @@ def sql_query(sql_to_compare: str, expression_type: str, value, field_datatype:
MetricsLayerFilterExpressionType.IsNotIn: lambda f: f.isin(value).negate(),
MetricsLayerFilterExpressionType.BooleanTrue: lambda f: LiteralValueCriterion(f),
MetricsLayerFilterExpressionType.BooleanFalse: lambda f: f.negate(),
MetricsLayerFilterExpressionType.IsTrue: lambda f: LiteralValueCriterion(f),
MetricsLayerFilterExpressionType.IsFalse: lambda f: f.negate(),
}

try:
Expand Down
77 changes: 52 additions & 25 deletions metrics_layer/core/sql/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
else:
self.query_type = definition["query_type"]
self.filter_type = filter_type
self._extra_group_by_filter_conditions = []

self.validate(definition)

Expand All @@ -80,7 +79,10 @@ def conditions(self):

@property
def is_group_by(self):
    """Whether this filter is resolved through a subquery (CTE).

    Returns:
        bool: True when the filter carries a ``group_by`` property (compared
        with ``is not None``) or when its expression is ``is_in_query`` /
        ``is_not_in_query``.
    """
    # Subquery expressions have no group_by property but still need a CTE,
    # so both conditions must be evaluated in a single return expression.
    return self.group_by is not None or self.expression in {
        MetricsLayerFilterExpressionType.IsInQuery.value,
        MetricsLayerFilterExpressionType.IsNotInQuery.value,
    }

@property
def is_funnel(self):
Expand Down Expand Up @@ -243,6 +245,33 @@ def _handle_cte_alias_replacement(

def isin_sql_query(self):
    """Return the criterion that checks this filter's field against its CTE.

    The CTE alias is looked up by ``hash(self)`` in
    ``self.group_by_filter_cte_lookup``. Filters with a truthy ``group_by``
    use the legacy group-by builder; all others use the subquery builder.
    """
    alias = self.group_by_filter_cte_lookup[hash(self)]
    build = self._create_legacy_group_by_is_in_query if self.group_by else self._create_is_in_query
    return build(alias)

def _create_is_in_query(self, cte_alias):
    """Build an ``isin`` / ``isnotin`` criterion against a subquery CTE.

    Args:
        cte_alias: Name of the CTE to select from.

    Returns:
        The criterion produced by a plain ``where`` filter whose value is a
        ``SELECT DISTINCT <connection field>`` over ``cte_alias``.

    Raises:
        QueryError: If ``self.expression`` is neither ``is_in_query`` nor
            ``is_not_in_query``.
    """
    # The connecting field is named by the filter value itself (value["field"]).
    target_field = self.design.get_field(self.value["field"])
    query_builder = query_lookup[self.query_type]
    distinct_values = (
        query_builder.from_(Table(cte_alias)).select(target_field.alias(with_view=True)).distinct()
    )

    # Translate the subquery expression into its plain IN / NOT IN counterpart.
    translations = (
        (MetricsLayerFilterExpressionType.IsNotInQuery.value, MetricsLayerFilterExpressionType.IsNotIn.value),
        (MetricsLayerFilterExpressionType.IsInQuery.value, MetricsLayerFilterExpressionType.IsIn.value),
    )
    for subquery_expression, plain_expression in translations:
        if self.expression == subquery_expression:
            break
    else:
        raise QueryError(f"Invalid expression for subquery filter: {self.expression}")

    inner_filter = MetricsLayerFilter(
        definition={
            "query_type": self.query_type,
            "field": self.field.id(),
            "expression": plain_expression,
            "value": distinct_values,
        },
        design=None,
        filter_type="where",
    )
    return inner_filter.criterion(self.field.sql_query(self.query_type))

def _create_legacy_group_by_is_in_query(self, cte_alias):
group_by_field = self.design.get_field(self.group_by)
base = query_lookup[self.query_type]
subquery = base.from_(Table(cte_alias)).select(group_by_field.alias(with_view=True)).distinct()
Expand Down Expand Up @@ -292,33 +321,31 @@ def criterion(self, field_sql: str) -> Criterion:
field_datatype = "unknown"
return Filter.sql_query(field_sql, self.expression_type, self.value, field_datatype)

def consolidate_group_by_filter(self, filter_class_to_consolidate: "MetricsLayerFilter") -> None:
    """
    Attach another filter to this group_by filter.

    Both filters must share the same ``group_by`` value, and their fields must
    have at least one join graph in common (ignoring graphs whose name contains
    "merged_result"). On success the other filter is appended to
    ``self._extra_group_by_filter_conditions``.

    Raises:
        QueryError: If this filter is not a group_by filter, the ``group_by``
            values differ, or the fields share no join graph.
    """

    def non_merged_graphs(field):
        # Keep only the join graphs that do not involve a merged result.
        return {graph for graph in field.join_graphs() if "merged_result" not in graph}

    if not self.is_group_by:
        raise QueryError("A group_by filter is invalid for a filter with no group_by property")

    other = filter_class_to_consolidate
    if self.group_by != other.group_by:
        raise QueryError("The group_by field must be the same for both filters")

    own_graphs = non_merged_graphs(self.field)
    other_graphs = non_merged_graphs(other.field)
    if not own_graphs & other_graphs:
        raise QueryError("The filters must have a join path in common to be consolidated")

    self._extra_group_by_filter_conditions.append(other)

def cte(self, query_class, design_class):
    """Return the subquery (CTE) that backs this filter.

    Args:
        query_class: Passed through to the group_by subquery builder.
        design_class: Passed through to the group_by subquery builder.

    Returns:
        The result of ``_create_subquery_from_group_by_property`` when the
        filter has a truthy ``group_by``, otherwise the result of
        ``_create_subquery_from_query_property`` for ``is_in_query`` /
        ``is_not_in_query`` expressions.

    Raises:
        QueryError: If the filter has neither a group_by property nor a
            subquery expression.
    """
    invalid_cte_message = (
        "A CTE is invalid for a filter with no group_by property or is_in_query/is_not_in_query"
        " expression"
    )
    if not self.is_group_by:
        raise QueryError(invalid_cte_message)

    if self.group_by:
        return self._create_subquery_from_group_by_property(query_class, design_class)

    if self.expression in {
        MetricsLayerFilterExpressionType.IsInQuery.value,
        MetricsLayerFilterExpressionType.IsNotInQuery.value,
    }:
        return self._create_subquery_from_query_property()

    # Still reachable: is_group_by tests `group_by is not None`, while the branch
    # above tests truthiness, so a falsy non-None group_by ends up here.
    raise QueryError(invalid_cte_message)

def _create_subquery_from_query_property(self):
# This is a subquery that's compiled in the `resolve.py` file in the initial parsing step.
return self.value["sql_query"]

def _create_subquery_from_group_by_property(self, query_class, design_class):
group_by_filters = [{k: v for k, v in self._definition.items() if k != "group_by"}]
for f in self._extra_group_by_filter_conditions:
group_by_filters.append({k: v for k, v in f._definition.items() if k != "group_by"})

field_lookup = {}
group_by_field = self.design.get_field(self.group_by)
Expand Down
21 changes: 15 additions & 6 deletions metrics_layer/core/sql/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def parse_definition(self, definition: dict):
access_filter_literal, _ = self.design.get_access_filter()
if where or access_filter_literal:
wheres, group_by_wheres, group_by_where_cte_lookup = self._parse_filter_object(
where, "where", access_filter=access_filter_literal
where,
"where",
access_filter=access_filter_literal,
nesting_depth=definition.get("nesting_depth", 0),
)
self.where_filters.extend([f for f in wheres if not f.is_funnel])
self.funnel_filters.extend([f for f in wheres if f.is_funnel])
Expand Down Expand Up @@ -102,7 +105,9 @@ def parse_definition(self, definition: dict):
}
)

def _parse_filter_object(self, filter_object, filter_type: str, access_filter: str = None):
def _parse_filter_object(
self, filter_object, filter_type: str, access_filter: str = None, nesting_depth: int = 0
):
results, group_by_results, group_by_cte_lookup = [], [], {}
extra_kwargs = dict(filter_type=filter_type, design=self.design)

Expand All @@ -121,10 +126,14 @@ def _parse_filter_object(self, filter_object, filter_type: str, access_filter: s
for filter_dict in filter_object:
flattened_filters = flatten_filters(filter_dict)
for sub_filter in flattened_filters:
if "group_by" in sub_filter:
gb_f = MetricsLayerFilter(definition=sub_filter, **extra_kwargs)
group_by_cte_lookup[hash(gb_f)] = f"filter_subquery_{cte_counter}"
group_by_results.append(gb_f)
f = MetricsLayerFilter(definition=sub_filter, **extra_kwargs)
if f.is_group_by:
if nesting_depth > 0:
cte_alias = f"filter_subquery_{nesting_depth}_{cte_counter}"
else:
cte_alias = f"filter_subquery_{cte_counter}"
group_by_cte_lookup[hash(f)] = cte_alias
group_by_results.append(f)
cte_counter += 1

for filter_dict in filter_object:
Expand Down
56 changes: 54 additions & 2 deletions metrics_layer/core/sql/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Union

from metrics_layer.core.exceptions import JoinError, QueryError
from metrics_layer.core.model.filter import Filter
from metrics_layer.core.model.filter import Filter, MetricsLayerFilterExpressionType
from metrics_layer.core.model.project import Project
from metrics_layer.core.sql.merged_query_resolve import MergedSQLQueryResolver
from metrics_layer.core.sql.query_base import QueryKindTypes
Expand Down Expand Up @@ -314,7 +314,9 @@ def _replace_field_value_in_group_by_filter(self):
optimal_join_graph_connection = [
o for o in optimal_join_graph_connection if "merged_result" not in o
]
flattened_conditions = SingleSQLQueryResolver.flatten_filters(self.where)
flattened_conditions = SingleSQLQueryResolver.flatten_filters(
self.where, return_nesting_depth=True
)
for cond in flattened_conditions:
if "group_by" in cond:
# Only the group by field needs to be joinable or merge-able to the query
Expand All @@ -332,6 +334,55 @@ def _replace_field_value_in_group_by_filter(self):
)
self.field_id_mapping[cond["field"]] = replace_with.id()
cond["field"] = replace_with.id()
elif cond["expression"] in {
MetricsLayerFilterExpressionType.IsInQuery.value,
MetricsLayerFilterExpressionType.IsNotInQuery.value,
}:
defaults = {
"project": self.project,
"connections": self.connections,
"model_name": self.model.name,
"return_pypika_query": False,
}
if "query_type" in self.kwargs:
defaults["query_type"] = self.kwargs["query_type"]

# This handles the case where the passed filter is incomplete, and
# does not apply the filter
if "query" not in cond["value"]:
continue

if "query" in cond["value"] and not isinstance(cond["value"]["query"], dict):
raise QueryError(
"Subquery filter value for the key 'query' must be a dictionary. It was"
f" {cond['value']['query']}"
)

if "apply_limit" in cond["value"] and not bool(cond["value"]["apply_limit"]):
cond["value"]["query"]["limit"] = None

if "nesting_depth" in cond and cond["nesting_depth"] > 0:
defaults["nesting_depth"] = cond["nesting_depth"]

resolver = SQLQueryResolver(**cond["value"]["query"], **defaults)
jg_connection = set.intersection(*map(set, resolver.field_lookup.values()))
optimal_jg_connection = [o for o in jg_connection if "merged_result" not in o]

mapped_field = self.project.get_mapped_field(cond["value"]["field"], model=self.model)
if mapped_field:
field = self.determine_field_to_replace_with(
mapped_field, optimal_jg_connection, jg_connection
)
self.field_id_mapping[cond["value"]["field"]] = field.id()
cond["value"]["field"] = field.id()
else:
field = self.project.get_field(cond["value"]["field"])
if field.id() not in {self.project.get_field(d).id() for d in resolver.dimensions}:
raise QueryError(
f"Field {field.id()} not found in subquery dimensions {resolver.dimensions}. You"
" must specify a dimension that is present in the subquery."
)
cond["value"]["sql_query"] = resolver.get_query(semicolon=False)

def _get_field_from_lookup(self, field_name: str, only_search_lookup: bool = False):
if field_name in self.field_object_lookup:
Expand Down Expand Up @@ -538,6 +589,7 @@ def _deduplicate_always_where_filters(filters: list):
def _clean_conditional_filter_syntax(self, filters: Union[str, None, List]):
if not filters or isinstance(filters, str):
return filters

if isinstance(filters, dict):
return [filters]

Expand Down
6 changes: 4 additions & 2 deletions metrics_layer/core/sql/single_query_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
self.funnel, self.is_funnel_query = self.parse_funnel(funnel)
self.parse_field_names(where, having, order_by)
self.model = model
self.nesting_depth = kwargs.get("nesting_depth", 0)
self.query_type = kwargs.get("query_type")
if self.query_type is None:
raise QueryError(
Expand Down Expand Up @@ -64,6 +65,7 @@ def get_query(self, semicolon: bool = True):
"select_raw_sql": self.select_raw_sql,
"limit": self.limit,
"return_pypika_query": self.return_pypika_query,
"nesting_depth": self.nesting_depth,
}
if self.has_cumulative_metric and self.is_funnel_query:
raise QueryError("Cumulative metrics cannot be used with funnel queries")
Expand Down Expand Up @@ -242,8 +244,8 @@ def parse_identifiers_from_dicts(conditions: list):
raise QueryError(f"Identifier was missing required 'field' key: {cond}")

@staticmethod
def flatten_filters(filters: list, return_nesting_depth: bool = False):
    """Flatten a filter structure by delegating to the module-level ``flatten_filters``.

    Args:
        filters: The filter structure to flatten.
        return_nesting_depth: Forwarded unchanged to the module-level helper.
            Defaults to False so existing one-argument callers are unaffected.

    Returns:
        Whatever the module-level ``flatten_filters`` returns.
    """
    # Inside a class body this name resolves to the module-level function,
    # not to this static method, so the call is not recursive.
    return flatten_filters(filters, return_nesting_depth=return_nesting_depth)

@staticmethod
def _check_for_dict(conditions: list):
Expand Down
16 changes: 11 additions & 5 deletions metrics_layer/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,25 @@ def generate_random_password(length):
return result_str


def flatten_filters(filters: list, return_nesting_depth: bool = False):
    """Flatten a nested filter structure into a flat list of leaf filter dicts.

    Lists and dicts with a ``"conditions"`` key are containers and are walked
    recursively; every other dict is a leaf and is collected in traversal
    order. Values that are neither dicts nor lists are ignored.

    Args:
        filters: A list of filters (a single dict is also accepted).
        return_nesting_depth: When True, each leaf dict is tagged IN PLACE
            with a ``"nesting_depth"`` key. The value is the number of
            containers entered so far in the traversal; it is never
            decremented, so later siblings can carry a larger value than
            their structural depth.

    Returns:
        list: The leaf dicts, which are the same objects as in ``filters``
        (not copies).
    """
    flat_list = []
    nesting_depth = 0

    def recurse(filter_obj):
        nonlocal nesting_depth
        if isinstance(filter_obj, dict):
            if "conditions" not in filter_obj:
                # Leaf filter: optionally tag it, then collect it.
                if return_nesting_depth:
                    filter_obj["nesting_depth"] = nesting_depth
                flat_list.append(filter_obj)
                return
            children = filter_obj["conditions"]
        elif isinstance(filter_obj, list):
            children = filter_obj
        else:
            return

        # Entering a container (list or "conditions" group) bumps the counter.
        nesting_depth += 1
        for child in children:
            recurse(child)

    recurse(filters)
    return flat_list
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "metrics_layer"
version = "0.12.47"
version = "0.12.48"
description = "The open source metrics layer."
authors = ["Paul Blankley <[email protected]>"]
keywords = ["Metrics Layer", "Business Intelligence", "Analytics"]
Expand Down
Loading

0 comments on commit e3c5d85

Please sign in to comment.