Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add support for incremental rewrite of nested queries #400

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.incremental;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.calcite.rel.RelNode;


public class IncrementalTransformerResults {

private RelNode incrementalRelNode;
private RelNode refreshRelNode;
private Map<String, RelNode> intermediateQueryRelNodes;
private List<String> intermediateOrderings;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed does it make sense to get rid of incrementalRelNode? Is it a special case of Map<String, RelNode> intermediateQueryRelNodes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intermediateOrderings --> materializationOrder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm now getting indices from the map so will get rid of this entirely.


public IncrementalTransformerResults() {
incrementalRelNode = null;
refreshRelNode = null;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be a list of refresh relnodes? We can remove from current PR till we converge on the best representation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this should be a list of refresh RelNodes—one for each aggregate query. Will remove for now.

intermediateQueryRelNodes = new LinkedHashMap<>();
intermediateOrderings = new ArrayList<>();
}

public boolean existsIncrementalRelNode() {
return incrementalRelNode != null;
}

public RelNode getIncrementalRelNode() {
return incrementalRelNode;
}

public boolean existsRefreshRelNode() {
return refreshRelNode != null;
}

public RelNode getRefreshRelNode() {
return refreshRelNode;
}

public Map<String, RelNode> getIntermediateQueryRelNodes() {
return intermediateQueryRelNodes;
}

public boolean containsIntermediateQueryRelNodeKey(String name) {
return intermediateQueryRelNodes.containsKey(name);
}

public List<String> getIntermediateOrderings() {
return intermediateOrderings;
}

public int getIndexOfIntermediateOrdering(String name) {
return intermediateOrderings.indexOf(name);
}

public void setIncrementalRelNode(RelNode incrementalRelNode) {
this.incrementalRelNode = incrementalRelNode;
}

public void setRefreshRelNode(RelNode refreshRelNode) {
this.refreshRelNode = refreshRelNode;
}

public void addIntermediateQueryRelNode(String name, RelNode intermediateRelNode) {
this.intermediateQueryRelNodes.put(name, intermediateRelNode);
addIntermediateOrdering(name);
}

public void addMultipleIntermediateQueryRelNodes(Map<String, RelNode> intermediateQueryRelNodes) {
if (intermediateQueryRelNodes != null) {
this.intermediateQueryRelNodes.putAll(intermediateQueryRelNodes);
addMultipleIntermediateOrderings(new ArrayList<>(intermediateQueryRelNodes.keySet()));
}
}

public void addIntermediateOrdering(String intermediateOrdering) {
this.intermediateOrderings.add(intermediateOrdering);
}

public void addMultipleIntermediateOrderings(List<String> intermediateOrderings) {
this.intermediateOrderings.addAll(intermediateOrderings);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.calcite.plan.RelOptSchema;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.prepare.RelOptTableImpl;
import org.apache.calcite.rel.RelNode;
Expand All @@ -29,14 +31,29 @@

public class RelNodeIncrementalTransformer {

private static RelOptSchema relOptSchema;

private RelNodeIncrementalTransformer() {
}

public static RelNode convertRelIncremental(RelNode originalNode) {
public static IncrementalTransformerResults performIncrementalTransformation(RelNode originalNode) {
IncrementalTransformerResults incrementalTransformerResults = convertRelIncremental(originalNode);
return incrementalTransformerResults;
}

private static IncrementalTransformerResults convertRelIncremental(RelNode originalNode) {
IncrementalTransformerResults incrementalTransformerResults = new IncrementalTransformerResults();
RelShuttle converter = new RelShuttleImpl() {
@Override
public RelNode visit(TableScan scan) {
RelOptTable originalTable = scan.getTable();

// Set relOptSchema
if (relOptSchema == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the use case for this check? when will relOptSchema not be present in the table?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is to set relOptSchema only when it has not previously already been set. I don't think it's absolutely necessary but I just added a comment to explain the conditional's purpose.

relOptSchema = originalTable.getRelOptSchema();
}

// Create delta scan
List<String> incrementalNames = new ArrayList<>(originalTable.getQualifiedName());
String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + "_delta";
incrementalNames.add(deltaTableName);
Expand All @@ -49,61 +66,130 @@ public RelNode visit(TableScan scan) {
public RelNode visit(LogicalJoin join) {
RelNode left = join.getLeft();
RelNode right = join.getRight();
RelNode incrementalLeft = convertRelIncremental(left);
RelNode incrementalRight = convertRelIncremental(right);
IncrementalTransformerResults incrementalTransformerResultsLeft = convertRelIncremental(left);
IncrementalTransformerResults incrementalTransformerResultsRight = convertRelIncremental(right);
RelNode incrementalLeft = incrementalTransformerResultsLeft.getIncrementalRelNode();
RelNode incrementalRight = incrementalTransformerResultsRight.getIncrementalRelNode();
incrementalTransformerResults
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will it simplify the logic if we make intermediateQueryRelNodes global?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we are actually getting rid of IncrementalTransformerResults and keeping all logic inside of the transformer itself (larger discussion in the Slack channel).

.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsLeft.getIntermediateQueryRelNodes());
incrementalTransformerResults
.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsRight.getIntermediateQueryRelNodes());

RexBuilder rexBuilder = join.getCluster().getRexBuilder();

// Check if we can replace the left and right nodes with a scan of a materialized table
if (incrementalTransformerResults.containsIntermediateQueryRelNodeKey(getTableNameFromDescription(left))) {
String description = getTableNameFromDescription(left);
String deterministicDescription =
"Table#" + incrementalTransformerResults.getIndexOfIntermediateOrdering(description);
LogicalProject leftLastProject =
createReplacementProjectNodeForGivenRelNode(deterministicDescription, left, rexBuilder);
left = leftLastProject;
LogicalProject leftDeltaProject = createReplacementProjectNodeForGivenRelNode(
deterministicDescription + "_delta", incrementalLeft, rexBuilder);
incrementalLeft = leftDeltaProject;
}
if (incrementalTransformerResults.containsIntermediateQueryRelNodeKey(getTableNameFromDescription(right))) {
String description = getTableNameFromDescription(right);
String deterministicDescription =
"Table#" + incrementalTransformerResults.getIndexOfIntermediateOrdering(description);
LogicalProject rightLastProject =
createReplacementProjectNodeForGivenRelNode(deterministicDescription, right, rexBuilder);
right = rightLastProject;
LogicalProject rightDeltaProject = createReplacementProjectNodeForGivenRelNode(
deterministicDescription + "_delta", incrementalRight, rexBuilder);
incrementalRight = rightDeltaProject;
}

LogicalProject p1 = createProjectOverJoin(join, left, incrementalRight, rexBuilder);
LogicalProject p2 = createProjectOverJoin(join, incrementalLeft, right, rexBuilder);
LogicalProject p3 = createProjectOverJoin(join, incrementalLeft, incrementalRight, rexBuilder);

LogicalUnion unionAllJoins =
LogicalUnion.create(Arrays.asList(LogicalUnion.create(Arrays.asList(p1, p2), true), p3), true);

return unionAllJoins;
}

@Override
public RelNode visit(LogicalFilter filter) {
RelNode transformedChild = convertRelIncremental(filter.getInput());
IncrementalTransformerResults incrementalTransformerResultsChild = convertRelIncremental(filter.getInput());
RelNode transformedChild = incrementalTransformerResultsChild.getIncrementalRelNode();
incrementalTransformerResults
.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsChild.getIntermediateQueryRelNodes());
return LogicalFilter.create(transformedChild, filter.getCondition());
}

@Override
public RelNode visit(LogicalProject project) {
RelNode transformedChild = convertRelIncremental(project.getInput());
return LogicalProject.create(transformedChild, project.getProjects(), project.getRowType());
IncrementalTransformerResults incrementalTransformerResultsChild = convertRelIncremental(project.getInput());
RelNode transformedChild = incrementalTransformerResultsChild.getIncrementalRelNode();
incrementalTransformerResults
.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsChild.getIntermediateQueryRelNodes());
incrementalTransformerResults.addIntermediateQueryRelNode(getTableNameFromDescription(project), project);
LogicalProject transformedProject =
LogicalProject.create(transformedChild, project.getProjects(), project.getRowType());
incrementalTransformerResults.addIntermediateQueryRelNode(getTableNameFromDescription(project) + "_delta",
transformedProject);
return transformedProject;
}

@Override
public RelNode visit(LogicalUnion union) {
List<RelNode> children = union.getInputs();
List<RelNode> transformedChildren =
List<IncrementalTransformerResults> incrementalTransformerResultsChildren =
children.stream().map(child -> convertRelIncremental(child)).collect(Collectors.toList());
List<RelNode> transformedChildren = new ArrayList<>();
for (IncrementalTransformerResults incrementalTransformerResultsChild : incrementalTransformerResultsChildren) {
transformedChildren.add(incrementalTransformerResultsChild.getIncrementalRelNode());
incrementalTransformerResults
.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsChild.getIntermediateQueryRelNodes());
}
return LogicalUnion.create(transformedChildren, union.all);
}

@Override
public RelNode visit(LogicalAggregate aggregate) {
RelNode transformedChild = convertRelIncremental(aggregate.getInput());
IncrementalTransformerResults incrementalTransformerResultsChild = convertRelIncremental(aggregate.getInput());
RelNode transformedChild = incrementalTransformerResultsChild.getIncrementalRelNode();
incrementalTransformerResults
.addMultipleIntermediateQueryRelNodes(incrementalTransformerResultsChild.getIntermediateQueryRelNodes());
return LogicalAggregate.create(transformedChild, aggregate.getGroupSet(), aggregate.getGroupSets(),
aggregate.getAggCallList());
}
};
return originalNode.accept(converter);
incrementalTransformerResults.setIncrementalRelNode(originalNode.accept(converter));
return incrementalTransformerResults;
}

private static LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right,
private static String getTableNameFromDescription(RelNode relNode) {
String identifier = relNode.getDescription().split("#")[1];
return "Table#" + identifier;
}

private static LogicalProject createReplacementProjectNodeForGivenRelNode(String relOptTableName, RelNode relNode,
RexBuilder rexBuilder) {
LogicalJoin incrementalJoin =
LogicalJoin.create(left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType());
RelOptTable table =
RelOptTableImpl.create(relOptSchema, relNode.getRowType(), Collections.singletonList(relOptTableName), null);
TableScan scan = LogicalTableScan.create(relNode.getCluster(), table);
return createProjectOverNode(scan, rexBuilder);
}

private static LogicalProject createProjectOverNode(RelNode relNode, RexBuilder rexBuilder) {
ArrayList<RexNode> projects = new ArrayList<>();
ArrayList<String> names = new ArrayList<>();
IntStream.range(0, incrementalJoin.getRowType().getFieldList().size()).forEach(i -> {
projects.add(rexBuilder.makeInputRef(incrementalJoin, i));
names.add(incrementalJoin.getRowType().getFieldNames().get(i));
IntStream.range(0, relNode.getRowType().getFieldList().size()).forEach(i -> {
projects.add(rexBuilder.makeInputRef(relNode, i));
names.add(relNode.getRowType().getFieldNames().get(i));
});
return LogicalProject.create(incrementalJoin, projects, names);
return LogicalProject.create(relNode, projects, names);
}

private static LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right,
RexBuilder rexBuilder) {
LogicalJoin incrementalJoin =
LogicalJoin.create(left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType());
return createProjectOverNode(incrementalJoin, rexBuilder);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ public void afterClass() throws IOException {
}

public String convert(RelNode relNode) {
RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode);
IncrementalTransformerResults incrementalTransformerResults =
RelNodeIncrementalTransformer.performIncrementalTransformation(relNode);
RelNode incrementalRelNode = incrementalTransformerResults.getIncrementalRelNode();
CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter();
SqlNode sqlNode = converter.convert(incrementalRelNode);
return sqlNode.toSqlString(converter.INSTANCE).getSql();
Expand Down Expand Up @@ -81,41 +83,6 @@ public void testJoinWithFilter() {
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testJoinWithNestedFilter() {
String sql =
"WITH tmp AS (SELECT * from test.bar1 WHERE test.bar1.x > 10), tmp2 AS (SELECT * from test.bar2) SELECT * FROM tmp JOIN tmp2 ON tmp.x = tmp2.x";
String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n"
+ "WHERE bar1.x > 10) AS t\n" + "INNER JOIN test.bar2_delta AS bar2_delta ON t.x = bar2_delta.x\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta\n"
+ "WHERE bar1_delta.x > 10) AS t0\n" + "INNER JOIN test.bar2 AS bar2 ON t0.x = bar2.x) AS t1\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n"
+ "WHERE bar1_delta0.x > 10) AS t2\n" + "INNER JOIN test.bar2_delta AS bar2_delta0 ON t2.x = bar2_delta0.x";
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testNestedJoin() {
String sql =
"WITH tmp AS (SELECT * FROM test.bar1 INNER JOIN test.bar2 ON test.bar1.x = test.bar2.x) SELECT * FROM tmp INNER JOIN test.bar3 ON tmp.x = test.bar3.x";
String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n"
+ "INNER JOIN test.bar2 AS bar2 ON bar1.x = bar2.x\n"
+ "INNER JOIN test.bar3_delta AS bar3_delta ON bar1.x = bar3_delta.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar10\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta ON bar10.x = bar2_delta.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.bar1_delta AS bar1_delta\n" + "INNER JOIN test.bar2 AS bar20 ON bar1_delta.x = bar20.x) AS t\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0\n"
+ "INNER JOIN test.bar3 AS bar3 ON t0.x = bar3.x) AS t1\n" + "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n"
+ "FROM (SELECT *\n" + "FROM test.bar1 AS bar11\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta1 ON bar11.x = bar2_delta1.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.bar1_delta AS bar1_delta1\n" + "INNER JOIN test.bar2 AS bar21 ON bar1_delta1.x = bar21.x) AS t2\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta2\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta2 ON bar1_delta2.x = bar2_delta2.x) AS t3\n"
+ "INNER JOIN test.bar3_delta AS bar3_delta0 ON t3.x = bar3_delta0.x";
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testUnion() {
String sql = "SELECT * FROM test.bar1 UNION SELECT * FROM test.bar2 UNION SELECT * FROM test.bar3";
Expand Down Expand Up @@ -143,4 +110,31 @@ public void testSelectSpecificJoin() {
+ "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0";
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testNestedJoin() {
String nestedJoin = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1";
String sql = "SELECT a2, g1 FROM (" + nestedJoin + ") AS nj JOIN test.gamma ON nj.a2 = test.gamma.g2";
String expected = "SELECT t0.a2, t0.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0 AS Table#0\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta ON Table#0.a2 = gamma_delta.g2\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM Table#0_delta AS Table#0_delta\n"
+ "INNER JOIN test.gamma AS gamma ON Table#0_delta.a2 = gamma.g2) AS t\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM Table#0_delta AS Table#0_delta0\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t0";
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testThreeNestedJoins() {
String nestedJoin1 = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1";
String nestedJoin2 = "SELECT a2, g1 FROM (" + nestedJoin1 + ") AS nj1 JOIN test.gamma ON nj1.a2 = test.gamma.g2";
String sql = "SELECT g1, e2 FROM (" + nestedJoin2 + ") AS nj2 JOIN test.epsilon ON nj2.g1 = test.epsilon.e1";
String expected = "SELECT t0.g1, t0.e2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#2 AS Table#2\n"
+ "INNER JOIN test.epsilon_delta AS epsilon_delta ON Table#2.g1 = epsilon_delta.e1\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM Table#2_delta AS Table#2_delta\n"
+ "INNER JOIN test.epsilon AS epsilon ON Table#2_delta.g1 = epsilon.e1) AS t\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM Table#2_delta AS Table#2_delta0\n"
+ "INNER JOIN test.epsilon_delta AS epsilon_delta0 ON Table#2_delta0.g1 = epsilon_delta0.e1) AS t0";
assertEquals(getIncrementalModification(sql), expected);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ public static void initializeViews(HiveConf conf) throws HiveException, MetaExce
run(driver, "CREATE TABLE IF NOT EXISTS test.bar1(x int, y double)");
run(driver, "CREATE TABLE IF NOT EXISTS test.bar2(x int, y double)");
run(driver, "CREATE TABLE IF NOT EXISTS test.bar3(x int, y double)");

run(driver, "CREATE TABLE IF NOT EXISTS test.alpha(a1 int, a2 double)");
run(driver, "CREATE TABLE IF NOT EXISTS test.beta(b1 int, b2 double)");
run(driver, "CREATE TABLE IF NOT EXISTS test.gamma(g1 int, g2 double)");
run(driver, "CREATE TABLE IF NOT EXISTS test.epsilon(e1 int, e2 double)");
}

public static HiveConf loadResourceHiveConf() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.calcite.rel.RelNode;

import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.incremental.IncrementalTransformerResults;
import com.linkedin.coral.incremental.RelNodeIncrementalTransformer;
import com.linkedin.coral.spark.CoralSpark;

Expand All @@ -18,7 +19,9 @@ public class IncrementalUtils {

public static String getSparkIncrementalQueryFromUserSql(String query) {
RelNode originalNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(originalNode);
IncrementalTransformerResults incrementalTransformerResults =
RelNodeIncrementalTransformer.performIncrementalTransformation(originalNode);
RelNode incrementalRelNode = incrementalTransformerResults.getIncrementalRelNode();
CoralSpark coralSpark = CoralSpark.create(incrementalRelNode);
return coralSpark.getSparkSql();
}
Expand Down