Skip to content

Commit

Permalink
WIP - combine 2 nodes into 1
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 17, 2025
1 parent 49c5ce6 commit 8eb5de3
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 329 deletions.
6 changes: 6 additions & 0 deletions metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,12 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102
return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name))

def with_new_table_alias(self, new_table_alias: str) -> SqlColumnReferenceExpression:
    """Build a column reference expression for the same column under a different table alias.

    Args:
        new_table_alias: The table alias to use in the returned expression.

    Returns:
        A new expression built via `from_table_and_column_names` that keeps this
        expression's column name and uses `new_table_alias` as the table alias.
    """
    # Only the alias changes; the column name is carried over from this expression.
    column_name = self.col_ref.column_name
    return SqlColumnReferenceExpression.from_table_and_column_names(
        table_alias=new_table_alias,
        column_name=column_name,
    )


@dataclass(frozen=True, eq=False)
class SqlColumnAliasReferenceExpression(SqlExpressionNode):
Expand Down
25 changes: 1 addition & 24 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster
Expand Down Expand Up @@ -85,7 +84,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -1951,29 +1949,8 @@ def _build_custom_offset_time_spine_node(
if {spec.time_granularity for spec in required_time_spine_specs} == {custom_grain}:
# TODO: If querying with only the same grain as is used in the offset_window, can use a simpler plan.
pass
# For custom offset windows queried with other granularities, first, build CustomGranularityBoundsNode.
# This will be used twice in the output node, and ideally will be turned into a CTE.
bounds_node = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node, custom_granularity_name=custom_grain.name
)
# Build a FilterElementsNode from bounds node to get required unique rows.
bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node)
bounds_specs = tuple(
bounds_data_set.instance_from_window_function(window_func).spec
for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE)
)
custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=custom_grain.name, date_part=None
).spec
filter_elements_node = FilterElementsNode.create(
parent_node=bounds_node,
include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs),
distinct=True,
)
# Pass both the CustomGranularityBoundsNode and the FilterElementsNode into the OffsetByCustomGranularityNode.
return OffsetByCustomGranularityNode.create(
custom_granularity_bounds_node=bounds_node,
filter_elements_node=filter_elements_node,
time_spine_node=time_spine_read_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)
Expand Down
9 changes: 0 additions & 9 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -128,10 +127,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
Expand Down Expand Up @@ -235,10 +230,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
Expand Down
64 changes: 0 additions & 64 deletions metricflow/dataflow/nodes/custom_granularity_bounds.py

This file was deleted.

33 changes: 8 additions & 25 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
Expand All @@ -12,36 +12,31 @@

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode


@dataclass(frozen=True, eq=False)
class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
"""For a given custom grain, offset its base grain by the requested number of custom grain periods.
Only accepts CustomGranularityBoundsNode as parent node.
Only accepts DataflowPlanNode as parent node.
"""

offset_window: MetricTimeWindow
required_time_spine_specs: Sequence[TimeDimensionSpec]
custom_granularity_bounds_node: CustomGranularityBoundsNode
filter_elements_node: FilterElementsNode
time_spine_node: DataflowPlanNode

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()

@staticmethod
def create( # noqa: D102
custom_granularity_bounds_node: CustomGranularityBoundsNode,
filter_elements_node: FilterElementsNode,
time_spine_node: DataflowPlanNode,
offset_window: MetricTimeWindow,
required_time_spine_specs: Sequence[TimeDimensionSpec],
) -> OffsetByCustomGranularityNode:
return OffsetByCustomGranularityNode(
parent_nodes=(custom_granularity_bounds_node, filter_elements_node),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
parent_nodes=(time_spine_node,),
time_spine_node=time_spine_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)
Expand Down Expand Up @@ -74,22 +69,10 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa:
def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> OffsetByCustomGranularityNode:
custom_granularity_bounds_node: Optional[CustomGranularityBoundsNode] = None
filter_elements_node: Optional[FilterElementsNode] = None
for parent_node in new_parent_nodes:
if isinstance(parent_node, CustomGranularityBoundsNode):
custom_granularity_bounds_node = parent_node
elif isinstance(parent_node, FilterElementsNode):
filter_elements_node = parent_node
assert custom_granularity_bounds_node and filter_elements_node, (
"Can't rewrite OffsetByCustomGranularityNode because the node requires a CustomGranularityBoundsNode and a "
f"FilterElementsNode as parents. Instead, got: {new_parent_nodes}"
)

assert len(new_parent_nodes) == 1
return OffsetByCustomGranularityNode(
parent_nodes=tuple(new_parent_nodes),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
time_spine_node=new_parent_nodes[0],
offset_window=self.offset_window,
required_time_spine_specs=self.required_time_spine_specs,
)
6 changes: 0 additions & 6 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -474,11 +473,6 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -475,12 +474,6 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCo
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -366,12 +365,6 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
Expand Down
5 changes: 0 additions & 5 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -208,10 +207,6 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_offset_by_custom_granularity_node(
self, node: OffsetByCustomGranularityNode
Expand Down
Loading

0 comments on commit 8eb5de3

Please sign in to comment.