Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add support for incremental rewrite of nested queries #400

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.calcite.plan.RelOptSchema;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.prepare.RelOptTableImpl;
import org.apache.calcite.rel.RelNode;
Expand All @@ -29,16 +33,63 @@

public class RelNodeIncrementalTransformer {

private RelNodeIncrementalTransformer() {
/** Prefix shared by every generated table name, e.g. {@code Table#0}. */
private static final String TABLE_NAME_PREFIX = "Table#";
/** Suffix appended to a table name to denote its delta counterpart, e.g. {@code Table#0_delta}. */
private static final String DELTA_SUFFIX = "_delta";

/** Schema used when creating scans over materialized tables; set from the first TableScan visited. */
private RelOptSchema relOptSchema;
/** Snapshot RelNodes recorded for each visited LogicalProject, keyed by RelNode-derived description. */
private final Map<String, RelNode> snapshotRelNodes;
/** Incremental (delta) RelNodes recorded for each visited LogicalProject, keyed by delta description. */
private final Map<String, RelNode> deltaRelNodes;
/** Project-over-join produced by the most recent join visit; consumed and reset by getTempLastRelNode(). */
private RelNode tempLastRelNode;

/**
 * Creates a transformer with empty snapshot/delta maps. LinkedHashMap is used because the deterministic
 * table names are derived from insertion order. relOptSchema and tempLastRelNode start out as null.
 */
public RelNodeIncrementalTransformer() {
  snapshotRelNodes = new LinkedHashMap<>();
  deltaRelNodes = new LinkedHashMap<>();
}

public static RelNode convertRelIncremental(RelNode originalNode) {
/**
 * Returns a new map holding the snapshot RelNodes re-keyed with deterministic names.
 * Each entry is named TABLE_NAME_PREFIX followed by its insertion position in snapshotRelNodes
 * (Table#0, Table#1, ...), which is the same name getDeterministicDescriptionFromDescription()
 * computes for that key. Numbering while iterating avoids copying the key set and running a
 * linear indexOf for every entry.
 * @return insertion-ordered map from deterministic snapshot table name to snapshot RelNode
 */
public Map<String, RelNode> getSnapshotRelNodes() {
  Map<String, RelNode> deterministicSnapshotRelNodes = new LinkedHashMap<>();
  int index = 0;
  for (RelNode snapshotRelNode : snapshotRelNodes.values()) {
    deterministicSnapshotRelNodes.put(TABLE_NAME_PREFIX + index, snapshotRelNode);
    index++;
  }
  return deterministicSnapshotRelNodes;
}

/**
 * Returns a new map holding the delta RelNodes re-keyed with deterministic names.
 * Each entry is named TABLE_NAME_PREFIX + insertion position in deltaRelNodes + DELTA_SUFFIX
 * (Table#0_delta, Table#1_delta, ...), which is the same name
 * getDeterministicDescriptionFromDescription() computes for that key. Numbering while iterating
 * avoids copying the key set and running a linear indexOf for every entry.
 * @return insertion-ordered map from deterministic delta table name to delta RelNode
 */
public Map<String, RelNode> getDeltaRelNodes() {
  Map<String, RelNode> deterministicDeltaRelNodes = new LinkedHashMap<>();
  int index = 0;
  for (RelNode deltaRelNode : deltaRelNodes.values()) {
    deterministicDeltaRelNodes.put(TABLE_NAME_PREFIX + index + DELTA_SUFFIX, deltaRelNode);
    index++;
  }
  return deterministicDeltaRelNodes;
}

/**
* Convert an input RelNode to an incremental RelNode. Populates snapshotRelNodes and deltaRelNodes.
* @param originalNode input RelNode to generate an incremental version for.
*/
public RelNode convertRelIncremental(RelNode originalNode) {
RelShuttle converter = new RelShuttleImpl() {
@Override
public RelNode visit(TableScan scan) {
RelOptTable originalTable = scan.getTable();

// Set RelNodeIncrementalTransformer class relOptSchema if not already set
if (relOptSchema == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the use case for this check? When will relOptSchema not be present in the table?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check sets relOptSchema only if it has not already been set. I don't think it's strictly necessary, but I added a comment to explain the conditional's purpose.

relOptSchema = originalTable.getRelOptSchema();
}

// Create delta scan
List<String> incrementalNames = new ArrayList<>(originalTable.getQualifiedName());
String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + "_delta";
String deltaTableName = incrementalNames.remove(incrementalNames.size() - 1) + DELTA_SUFFIX;
incrementalNames.add(deltaTableName);
RelOptTable incrementalTable =
RelOptTableImpl.create(originalTable.getRelOptSchema(), originalTable.getRowType(), incrementalNames, null);
Expand All @@ -54,12 +105,34 @@ public RelNode visit(LogicalJoin join) {

RexBuilder rexBuilder = join.getCluster().getRexBuilder();

// Check if we can replace the left and right nodes with a scan of a materialized table
String leftDescription = getDescriptionFromRelNode(left, false);
String leftIncrementalDescription = getDescriptionFromRelNode(left, true);
if (snapshotRelNodes.containsKey(leftDescription)) {
left =
susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(leftDescription, false), left);
incrementalLeft = susbstituteWithMaterializedView(
getDeterministicDescriptionFromDescription(leftIncrementalDescription, true), incrementalLeft);
}
String rightDescription = getDescriptionFromRelNode(right, false);
String rightIncrementalDescription = getDescriptionFromRelNode(right, true);
if (snapshotRelNodes.containsKey(rightDescription)) {
right = susbstituteWithMaterializedView(getDeterministicDescriptionFromDescription(rightDescription, false),
right);
incrementalRight = susbstituteWithMaterializedView(
getDeterministicDescriptionFromDescription(rightIncrementalDescription, true), incrementalRight);
}

// We need to do this in the join to get potentially updated left and right nodes
tempLastRelNode = createProjectOverJoin(join, left, right, rexBuilder);

LogicalProject p1 = createProjectOverJoin(join, left, incrementalRight, rexBuilder);
LogicalProject p2 = createProjectOverJoin(join, incrementalLeft, right, rexBuilder);
LogicalProject p3 = createProjectOverJoin(join, incrementalLeft, incrementalRight, rexBuilder);

LogicalUnion unionAllJoins =
LogicalUnion.create(Arrays.asList(LogicalUnion.create(Arrays.asList(p1, p2), true), p3), true);

return unionAllJoins;
}

Expand All @@ -72,7 +145,16 @@ public RelNode visit(LogicalFilter filter) {
/**
 * Rewrites a LogicalProject incrementally and records both its snapshot and its delta form.
 * The snapshot entry is the project-over-join left behind by a join visit while the input was
 * being transformed (if any), otherwise the original project. The delta entry is the project
 * re-created over the transformed input.
 * @param project project to rewrite
 * @return the project re-created on top of the incrementally transformed input
 */
@Override
public RelNode visit(LogicalProject project) {
  RelNode transformedChild = convertRelIncremental(project.getInput());

  // getTempLastRelNode() also resets the field, so it must be read exactly once per project.
  RelNode materializedProject = getTempLastRelNode();
  snapshotRelNodes.put(getDescriptionFromRelNode(project, false),
      materializedProject != null ? materializedProject : project);

  LogicalProject transformedProject =
      LogicalProject.create(transformedChild, project.getProjects(), project.getRowType());
  deltaRelNodes.put(getDescriptionFromRelNode(project, true), transformedProject);
  return transformedProject;
}

@Override
Expand All @@ -93,8 +175,67 @@ public RelNode visit(LogicalAggregate aggregate) {
return originalNode.accept(converter);
}

private static LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right,
RexBuilder rexBuilder) {
/**
 * Hands out the pending tempLastRelNode and clears the field in the same call, so a second call made
 * before the field is assigned again returns null.
 * @return the RelNode currently held in tempLastRelNode, or null if none is pending
 */
private RelNode getTempLastRelNode() {
  try {
    return tempLastRelNode;
  } finally {
    // Runs after the return value has been captured, so the caller still gets the old value.
    tempLastRelNode = null;
  }
}

/**
 * Returns the internal description for a RelNode: TABLE_NAME_PREFIX followed by the node's identifier
 * (ex. the identifier for LogicalProject#22 is 22, giving Table#22), with DELTA_SUFFIX appended when
 * delta is true. The identifier is read through RelNode.getId() rather than parsed out of
 * getDescription(), so the result does not depend on the textual format of the description.
 * @param relNode RelNode from which the identifier will be retrieved.
 * @param delta configure whether to get the delta name
 * @return the snapshot description, or the delta description when delta is true
 */
private String getDescriptionFromRelNode(RelNode relNode, boolean delta) {
  String description = TABLE_NAME_PREFIX + relNode.getId();
  return delta ? description + DELTA_SUFFIX : description;
}

/**
 * Maps an internal description to a name that depends only on the order in which entries were recorded:
 * TABLE_NAME_PREFIX + position of the description among the keys of deltaRelNodes (delta) or
 * snapshotRelNodes (snapshot), plus DELTA_SUFFIX for the delta case. A description that is not a key
 * of the selected map yields position -1.
 * @param description output from calling getDescriptionFromRelNode()
 * @param delta configure whether to get the delta name
 * @return the deterministic table name for the description
 */
private String getDeterministicDescriptionFromDescription(String description, boolean delta) {
  Map<String, RelNode> recordedNodes = delta ? deltaRelNodes : snapshotRelNodes;
  int position = new ArrayList<>(recordedNodes.keySet()).indexOf(description);
  String deterministicName = TABLE_NAME_PREFIX + position;
  return delta ? deterministicName + DELTA_SUFFIX : deterministicName;
}

/**
 * Builds a LogicalTableScan that stands in for the given RelNode. The scanned table is created from the
 * class relOptSchema, the RelNode's row type, and a single-element qualified name.
 * @param relOptTableName name of the table the scan reads from
 * @param relNode top-level RelNode being replaced; supplies the row type and the cluster
 * @return a scan over the newly created table, in the same cluster as relNode
 */
private TableScan susbstituteWithMaterializedView(String relOptTableName, RelNode relNode) {
  List<String> qualifiedName = Collections.singletonList(relOptTableName);
  RelOptTable materializedTable = RelOptTableImpl.create(relOptSchema, relNode.getRowType(), qualifiedName, null);
  return LogicalTableScan.create(relNode.getCluster(), materializedTable);
}

/** Creates a LogicalProject whose input is an incremental LogicalJoin node that is constructed from a left and right
* RelNode and LogicalJoin.
* @param join LogicalJoin to create the incremental join from
* @param left left RelNode child of the incremental join
* @param right right RelNode child of the incremental join
* @param rexBuilder RexBuilder for LogicalProject creation
*/
private LogicalProject createProjectOverJoin(LogicalJoin join, RelNode left, RelNode right, RexBuilder rexBuilder) {
LogicalJoin incrementalJoin =
LogicalJoin.create(left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType());
ArrayList<RexNode> projects = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import java.io.File;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.sql.SqlNode;
Expand Down Expand Up @@ -41,7 +43,8 @@ public void afterClass() throws IOException {
}

public String convert(RelNode relNode) {
RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode);
RelNodeIncrementalTransformer transformer = new RelNodeIncrementalTransformer();
RelNode incrementalRelNode = transformer.convertRelIncremental(relNode);
CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter();
SqlNode sqlNode = converter.convert(incrementalRelNode);
return sqlNode.toSqlString(converter.INSTANCE).getSql();
Expand All @@ -52,6 +55,28 @@ public String getIncrementalModification(String sql) {
return convert(originalRelNode);
}

/**
 * Runs the incremental transformer on the given SQL and checks every recorded snapshot and delta RelNode
 * against the expected SQL text. The recorded table names must match the expected names exactly, so a
 * missing or extra entry fails the test instead of being silently skipped.
 * @param sql query to transform
 * @param snapshotExpected expected SQL per deterministic snapshot table name
 * @param deltaExpected expected SQL per deterministic delta table name
 */
public void checkAllSnapshotAndDeltaQueries(String sql, Map<String, String> snapshotExpected,
    Map<String, String> deltaExpected) {
  RelNode originalRelNode = hiveToRelConverter.convertSql(sql);
  // One converter is shared by both passes (snapshots first, then deltas): the expected strings rely on
  // aliases continuing across conversions (e.g. "alpha" in a snapshot, "alpha0" in the following delta).
  CoralRelToSqlNodeConverter converter = new CoralRelToSqlNodeConverter();
  RelNodeIncrementalTransformer transformer = new RelNodeIncrementalTransformer();
  transformer.convertRelIncremental(originalRelNode);
  assertTranslatesTo(converter, transformer.getSnapshotRelNodes(), snapshotExpected);
  assertTranslatesTo(converter, transformer.getDeltaRelNodes(), deltaExpected);
}

/** Asserts that both maps have the same keys and that each RelNode converts to the expected SQL. */
private void assertTranslatesTo(CoralRelToSqlNodeConverter converter, Map<String, RelNode> actualRelNodes,
    Map<String, String> expectedSql) {
  assertEquals(actualRelNodes.keySet(), expectedSql.keySet());
  for (Map.Entry<String, RelNode> entry : actualRelNodes.entrySet()) {
    SqlNode sqlNode = converter.convert(entry.getValue());
    String actualSql = sqlNode.toSqlString(converter.INSTANCE).getSql();
    assertEquals(actualSql, expectedSql.get(entry.getKey()));
  }
}

@Test
public void testSimpleSelectAll() {
String sql = "SELECT * FROM test.foo";
Expand Down Expand Up @@ -81,41 +106,6 @@ public void testJoinWithFilter() {
assertEquals(getIncrementalModification(sql), expected);
}

/**
 * Checks the single-query incremental rewrite of a join between two CTEs, one of which filters its table.
 * The expected SQL is a UNION ALL of three joins: filtered bar1 with bar2_delta, filtered bar1_delta with
 * bar2, and filtered bar1_delta with bar2_delta.
 */
@Test
public void testJoinWithNestedFilter() {
String sql =
"WITH tmp AS (SELECT * from test.bar1 WHERE test.bar1.x > 10), tmp2 AS (SELECT * from test.bar2) SELECT * FROM tmp JOIN tmp2 ON tmp.x = tmp2.x";
// The filter is expected to be repeated on top of bar1_delta in the branches that read the delta.
String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n"
+ "WHERE bar1.x > 10) AS t\n" + "INNER JOIN test.bar2_delta AS bar2_delta ON t.x = bar2_delta.x\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta\n"
+ "WHERE bar1_delta.x > 10) AS t0\n" + "INNER JOIN test.bar2 AS bar2 ON t0.x = bar2.x) AS t1\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n"
+ "WHERE bar1_delta0.x > 10) AS t2\n" + "INNER JOIN test.bar2_delta AS bar2_delta0 ON t2.x = bar2_delta0.x";
assertEquals(getIncrementalModification(sql), expected);
}

/**
 * Checks the single-query incremental rewrite of a join whose left input is itself a join (defined as a CTE).
 * The expected SQL inlines the nested join's three-way delta union wherever the outer rewrite needs the
 * delta of the left input, and joins it with bar3 / bar3_delta.
 */
@Test
public void testNestedJoin() {
String sql =
"WITH tmp AS (SELECT * FROM test.bar1 INNER JOIN test.bar2 ON test.bar1.x = test.bar2.x) SELECT * FROM tmp INNER JOIN test.bar3 ON tmp.x = test.bar3.x";
// Outer branches in order: (bar1 join bar2) with bar3_delta, nested delta with bar3, nested delta with bar3_delta.
String expected = "SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar1\n"
+ "INNER JOIN test.bar2 AS bar2 ON bar1.x = bar2.x\n"
+ "INNER JOIN test.bar3_delta AS bar3_delta ON bar1.x = bar3_delta.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.bar1 AS bar10\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta ON bar10.x = bar2_delta.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.bar1_delta AS bar1_delta\n" + "INNER JOIN test.bar2 AS bar20 ON bar1_delta.x = bar20.x) AS t\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta0\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0\n"
+ "INNER JOIN test.bar3 AS bar3 ON t0.x = bar3.x) AS t1\n" + "UNION ALL\n" + "SELECT *\n" + "FROM (SELECT *\n"
+ "FROM (SELECT *\n" + "FROM test.bar1 AS bar11\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta1 ON bar11.x = bar2_delta1.x\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.bar1_delta AS bar1_delta1\n" + "INNER JOIN test.bar2 AS bar21 ON bar1_delta1.x = bar21.x) AS t2\n"
+ "UNION ALL\n" + "SELECT *\n" + "FROM test.bar1_delta AS bar1_delta2\n"
+ "INNER JOIN test.bar2_delta AS bar2_delta2 ON bar1_delta2.x = bar2_delta2.x) AS t3\n"
+ "INNER JOIN test.bar3_delta AS bar3_delta0 ON t3.x = bar3_delta0.x";
assertEquals(getIncrementalModification(sql), expected);
}

@Test
public void testUnion() {
String sql = "SELECT * FROM test.bar1 UNION SELECT * FROM test.bar2 UNION SELECT * FROM test.bar3";
Expand Down Expand Up @@ -143,4 +133,68 @@ public void testSelectSpecificJoin() {
+ "INNER JOIN test.bar2_delta AS bar2_delta0 ON bar1_delta0.x = bar2_delta0.x) AS t0";
assertEquals(getIncrementalModification(sql), expected);
}

/**
 * Checks the snapshot and delta queries recorded for a join whose left input is a nested join subquery.
 * Two tables are expected: Table#0 for the inner join (alpha, beta) and Table#1 for the outer join, which
 * reads Table#0 / Table#0_delta instead of repeating the inner join.
 */
@Test
public void testNestedJoin() {
String nestedJoin = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1";
String sql = "SELECT a2, g1 FROM (" + nestedJoin + ") AS nj JOIN test.gamma ON nj.a2 = test.gamma.g2";
// Expected snapshot SQL, keyed by deterministic table name (insertion order matters: LinkedHashMap).
Map<String, String> snapshotExpected = new LinkedHashMap<>();
snapshotExpected.put("Table#0",
"SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1");
snapshotExpected.put("Table#1",
"SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2");
// Expected delta SQL: each entry is a projection over the UNION ALL of the three snapshot/delta join combinations.
Map<String, String> deltaExpected = new LinkedHashMap<>();
deltaExpected.put("Table#0_delta",
"SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha AS alpha0\n"
+ "INNER JOIN test.beta_delta AS beta_delta ON alpha0.a1 = beta_delta.b1\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.alpha_delta AS alpha_delta\n"
+ "INNER JOIN test.beta AS beta0 ON alpha_delta.a1 = beta0.b1) AS t\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.alpha_delta AS alpha_delta0\n"
+ "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0");
deltaExpected.put("Table#1_delta",
"SELECT t3.a2, t3.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0 AS Table#00\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta ON Table#00.a2 = gamma_delta.g2\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n"
+ "INNER JOIN test.gamma AS gamma0 ON Table#0_delta.a2 = gamma0.g2) AS t2\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM Table#0_delta AS Table#0_delta0\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t3");
checkAllSnapshotAndDeltaQueries(sql, snapshotExpected, deltaExpected);
}

/**
 * Checks the snapshot and delta queries recorded for three levels of nested joins.
 * Three tables are expected: Table#0 (alpha join beta), Table#1 (Table#0 join gamma) and Table#2
 * (Table#1 join epsilon); each level reads the previous level's table / delta table.
 */
@Test
public void testThreeNestedJoins() {
String nestedJoin1 = "SELECT a1, a2 FROM test.alpha JOIN test.beta ON test.alpha.a1 = test.beta.b1";
String nestedJoin2 = "SELECT a2, g1 FROM (" + nestedJoin1 + ") AS nj1 JOIN test.gamma ON nj1.a2 = test.gamma.g2";
String sql = "SELECT g1, e2 FROM (" + nestedJoin2 + ") AS nj2 JOIN test.epsilon ON nj2.g1 = test.epsilon.e1";
// Expected snapshot SQL, keyed by deterministic table name, innermost join first.
Map<String, String> snapshotExpected = new LinkedHashMap<>();
snapshotExpected.put("Table#0",
"SELECT *\n" + "FROM test.alpha AS alpha\n" + "INNER JOIN test.beta AS beta ON alpha.a1 = beta.b1");
snapshotExpected.put("Table#1",
"SELECT *\n" + "FROM Table#0 AS Table#0\n" + "INNER JOIN test.gamma AS gamma ON Table#0.a2 = gamma.g2");
snapshotExpected.put("Table#2",
"SELECT *\n" + "FROM Table#1 AS Table#1\n" + "INNER JOIN test.epsilon AS epsilon ON Table#1.g1 = epsilon.e1");
// Expected delta SQL: each entry is a projection over the UNION ALL of the three snapshot/delta join combinations.
Map<String, String> deltaExpected = new LinkedHashMap<>();
deltaExpected.put("Table#0_delta",
"SELECT t0.a1, t0.a2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM test.alpha AS alpha0\n"
+ "INNER JOIN test.beta_delta AS beta_delta ON alpha0.a1 = beta_delta.b1\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.alpha_delta AS alpha_delta\n"
+ "INNER JOIN test.beta AS beta0 ON alpha_delta.a1 = beta0.b1) AS t\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM test.alpha_delta AS alpha_delta0\n"
+ "INNER JOIN test.beta_delta AS beta_delta0 ON alpha_delta0.a1 = beta_delta0.b1) AS t0");
deltaExpected.put("Table#1_delta",
"SELECT t3.a2, t3.g1\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#0 AS Table#00\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta ON Table#00.a2 = gamma_delta.g2\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM Table#0_delta AS Table#0_delta\n"
+ "INNER JOIN test.gamma AS gamma0 ON Table#0_delta.a2 = gamma0.g2) AS t2\n" + "UNION ALL\n" + "SELECT *\n"
+ "FROM Table#0_delta AS Table#0_delta0\n"
+ "INNER JOIN test.gamma_delta AS gamma_delta0 ON Table#0_delta0.a2 = gamma_delta0.g2) AS t3");
deltaExpected.put("Table#2_delta",
"SELECT t6.g1, t6.e2\n" + "FROM (SELECT *\n" + "FROM (SELECT *\n" + "FROM Table#1 AS Table#10\n"
+ "INNER JOIN test.epsilon_delta AS epsilon_delta ON Table#10.g1 = epsilon_delta.e1\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta\n"
+ "INNER JOIN test.epsilon AS epsilon0 ON Table#1_delta.g1 = epsilon0.e1) AS t5\n" + "UNION ALL\n"
+ "SELECT *\n" + "FROM Table#1_delta AS Table#1_delta0\n"
+ "INNER JOIN test.epsilon_delta AS epsilon_delta0 ON Table#1_delta0.g1 = epsilon_delta0.e1) AS t6");
checkAllSnapshotAndDeltaQueries(sql, snapshotExpected, deltaExpected);
}
}
Loading