Skip to content

Commit

Permalink
修复对groupby解析对支持
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchao committed Mar 30, 2020
1 parent 2833176 commit a7967f7
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 38 deletions.
77 changes: 56 additions & 21 deletions core/src/main/java/com/dtstack/flink/sql/side/JoinNodeDealer.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,
Queue<Object> queueInfo,
SqlNode parentWhere,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
Set<Tuple2<String, String>> joinFieldSet,
Map<String, String> tableRef,
Map<String, String> fieldRef) {
Expand All @@ -105,20 +106,20 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,
if (leftNode.getKind() == JOIN) {
//处理连续join
dealNestJoin(joinNode, sideTableSet,
queueInfo, parentWhere, parentSelectList, joinFieldSet, tableRef, fieldRef, parentSelectList);
queueInfo, parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef);
leftNode = joinNode.getLeft();
}

if (leftNode.getKind() == AS) {
AliasInfo aliasInfo = (AliasInfo) sideSQLParser.parseSql(leftNode, sideTableSet, queueInfo, parentWhere, parentSelectList);
AliasInfo aliasInfo = (AliasInfo) sideSQLParser.parseSql(leftNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
leftTbName = aliasInfo.getName();
leftTbAlias = aliasInfo.getAlias();
}

boolean leftIsSide = checkIsSideTable(leftTbName, sideTableSet);
Preconditions.checkState(!leftIsSide, "side-table must be at the right of join operator");

Tuple2<String, String> rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList);
Tuple2<String, String> rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
rightTableName = rightTableNameAndAlias.f0;
rightTableAlias = rightTableNameAndAlias.f1;

Expand All @@ -145,7 +146,7 @@ public JoinInfo dealJoinNode(SqlJoin joinNode,

//extract 需要查询的字段信息
if(rightIsSide){
extractJoinNeedSelectField(leftNode, rightNode, parentWhere, parentSelectList, tableRef, joinFieldSet, fieldRef, tableInfo);
extractJoinNeedSelectField(leftNode, rightNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet, fieldRef, tableInfo);
}

if(tableInfo.getLeftNode().getKind() != AS){
Expand All @@ -168,13 +169,14 @@ public void extractJoinNeedSelectField(SqlNode leftNode,
SqlNode rightNode,
SqlNode parentWhere,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
Map<String, String> tableRef,
Set<Tuple2<String, String>> joinFieldSet,
Map<String, String> fieldRef,
JoinInfo tableInfo){

Set<String> extractSelectField = extractField(leftNode, parentWhere, parentSelectList, tableRef, joinFieldSet);
Set<String> rightExtractSelectField = extractField(rightNode, parentWhere, parentSelectList, tableRef, joinFieldSet);
Set<String> extractSelectField = extractField(leftNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet);
Set<String> rightExtractSelectField = extractField(rightNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet);

//重命名right 中和 left 重名的
Map<String, String> leftTbSelectField = Maps.newHashMap();
Expand Down Expand Up @@ -208,13 +210,15 @@ public void extractJoinNeedSelectField(SqlNode leftNode,
* @param sqlNode
* @param parentWhere
* @param parentSelectList
* @param parentGroupByList
* @param tableRef
* @param joinFieldSet
* @return
*/
public Set<String> extractField(SqlNode sqlNode,
SqlNode parentWhere,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
Map<String, String> tableRef,
Set<Tuple2<String, String>> joinFieldSet){
Set<String> fromTableNameSet = Sets.newHashSet();
Expand All @@ -225,8 +229,11 @@ public Set<String> extractField(SqlNode sqlNode,
Set<String> extractSelectField = extractSelectFields(parentSelectList, fromTableNameSet, tableRef);
Set<String> fieldFromJoinCondition = extractSelectFieldFromJoinCondition(joinFieldSet, fromTableNameSet, tableRef);

Set<String> extractGroupByField = extractFieldFromGroupByList(parentGroupByList, fromTableNameSet, tableRef);

extractSelectField.addAll(extractCondition);
extractSelectField.addAll(fieldFromJoinCondition);
extractSelectField.addAll(extractGroupByField);

return extractSelectField;
}
Expand All @@ -242,27 +249,27 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,
Set<String> sideTableSet,
Queue<Object> queueInfo,
SqlNode parentWhere,
SqlNodeList selectList,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
Set<Tuple2<String, String>> joinFieldSet,
Map<String, String> tableRef,
Map<String, String> fieldRef,
SqlNodeList parentSelectList){
Map<String, String> fieldRef){

SqlJoin leftJoinNode = (SqlJoin) joinNode.getLeft();
SqlNode parentRightJoinNode = joinNode.getRight();
SqlNode rightNode = leftJoinNode.getRight();
Tuple2<String, String> rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, selectList);
Tuple2<String, String> parentRightJoinInfo = parseRightNode(parentRightJoinNode, sideTableSet, queueInfo, parentWhere, selectList);
Tuple2<String, String> rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
Tuple2<String, String> parentRightJoinInfo = parseRightNode(parentRightJoinNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
boolean parentRightIsSide = checkIsSideTable(parentRightJoinInfo.f0, sideTableSet);

JoinInfo joinInfo = dealJoinNode(leftJoinNode, sideTableSet, queueInfo, parentWhere, selectList, joinFieldSet, tableRef, fieldRef);
JoinInfo joinInfo = dealJoinNode(leftJoinNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef);

String rightTableName = rightTableNameAndAlias.f0;
boolean rightIsSide = checkIsSideTable(rightTableName, sideTableSet);
SqlBasicCall buildAs = TableUtils.buildAsNodeByJoinInfo(joinInfo, null, null);

if(rightIsSide){
addSideInfoToExeQueue(queueInfo, joinInfo, joinNode, parentSelectList, parentWhere, tableRef);
addSideInfoToExeQueue(queueInfo, joinInfo, joinNode, parentSelectList, parentGroupByList, parentWhere, tableRef);
}

SqlNode newLeftNode = joinNode.getLeft();
Expand All @@ -275,7 +282,7 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,

//替换leftNode 为新的查询
joinNode.setLeft(buildAs);
replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentWhere);
replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentGroupByList, parentWhere);
}

return joinInfo;
Expand All @@ -288,13 +295,15 @@ private JoinInfo dealNestJoin(SqlJoin joinNode,
* @param joinInfo
* @param joinNode
* @param parentSelectList
* @param parentGroupByList
* @param parentWhere
* @param tableRef
*/
public void addSideInfoToExeQueue(Queue<Object> queueInfo,
JoinInfo joinInfo,
SqlJoin joinNode,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
SqlNode parentWhere,
Map<String, String> tableRef){
//只处理维表
Expand All @@ -308,7 +317,7 @@ public void addSideInfoToExeQueue(Queue<Object> queueInfo,
//替换左表为新的表名称
joinNode.setLeft(buildAs);

replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentWhere);
replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentGroupByList, parentWhere);
}

/**
Expand All @@ -317,12 +326,14 @@ public void addSideInfoToExeQueue(Queue<Object> queueInfo,
* @param leftJoinNode
* @param tableRef
* @param parentSelectList
* @param parentGroupByList
* @param parentWhere
*/
public void replaceSelectAndWhereField(SqlBasicCall buildAs,
SqlNode leftJoinNode,
Map<String, String> tableRef,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList,
SqlNode parentWhere){

String newLeftTableName = buildAs.getOperands()[1].toString();
Expand All @@ -341,10 +352,20 @@ public void replaceSelectAndWhereField(SqlBasicCall buildAs,
}
}

//TODO 应该根据上面的查询字段的关联关系来替换
//替换where 中的条件相关
for(String tbTmp : fromTableNameSet){
TableUtils.replaceWhereCondition(parentWhere, tbTmp, newLeftTableName);
TableUtils.replaceWhereCondition(parentWhere, tbTmp, newLeftTableName, fieldReplaceRef);
}

if(parentGroupByList != null){
for(SqlNode sqlNode : parentGroupByList.getList()){
for(String tbTmp : fromTableNameSet) {
TableUtils.replaceSelectFieldTable(sqlNode, tbTmp, newLeftTableName, fieldReplaceRef);
}
}
}

}

/**
Expand Down Expand Up @@ -407,7 +428,7 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias,

//替换where 中的条件相关
for(String tbTmp : fromTableNameSet){
TableUtils.replaceWhereCondition(parentWhere, tbTmp, tableAlias);
TableUtils.replaceWhereCondition(parentWhere, tbTmp, tableAlias, fieldReplaceRef);
}

for(String tbTmp : fromTableNameSet){
Expand All @@ -426,7 +447,6 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias,

/**
* 抽取上层需用使用到的字段
* 由于where字段已经抽取到上一层了所以不用查询出来
* @param parentSelectList
* @param fromTableNameSet
* @return
Expand All @@ -451,7 +471,6 @@ private Set<String> extractSelectFieldFromJoinCondition(Set<Tuple2<String, Strin
extractFieldList.add(field.f0 + "." + field.f1);
}

//TODO
if(tableRef.containsKey(field.f0)){
if(fromTableNameSet.contains(tableRef.get(field.f0))){
extractFieldList.add(tableRef.get(field.f0) + "." + field.f1);
Expand All @@ -462,6 +481,22 @@ private Set<String> extractSelectFieldFromJoinCondition(Set<Tuple2<String, Strin
return extractFieldList;
}

private Set<String> extractFieldFromGroupByList(SqlNodeList parentGroupByList,
Set<String> fromTableNameSet,
Map<String, String> tableRef){

if(parentGroupByList == null){
return Sets.newHashSet();
}

Set<String> extractFieldList = Sets.newHashSet();
for(SqlNode selectNode : parentGroupByList.getList()){
extractSelectField(selectNode, extractFieldList, fromTableNameSet, tableRef);
}

return extractFieldList;
}

/**
* 从join的条件中获取字段信息
* @param condition
Expand Down Expand Up @@ -573,12 +608,12 @@ private void extractSelectField(SqlNode selectNode,


private Tuple2<String, String> parseRightNode(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object> queueInfo,
SqlNode parentWhere, SqlNodeList selectList) {
SqlNode parentWhere, SqlNodeList selectList, SqlNodeList parentGroupByList) {
Tuple2<String, String> tabName = new Tuple2<>("", "");
if(sqlNode.getKind() == IDENTIFIER){
tabName.f0 = sqlNode.toString();
}else{
AliasInfo aliasInfo = (AliasInfo)sideSQLParser.parseSql(sqlNode, sideTableSet, queueInfo, parentWhere, selectList);
AliasInfo aliasInfo = (AliasInfo)sideSQLParser.parseSql(sqlNode, sideTableSet, queueInfo, parentWhere, selectList, parentGroupByList);
tabName.f0 = aliasInfo.getName();
tabName.f1 = aliasInfo.getAlias();
}
Expand Down
28 changes: 17 additions & 11 deletions core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Queue<Object> getExeQueue(String exeSql, Set<String> sideTableSet) throws
SqlParser sqlParser = SqlParser.create(exeSql, CalciteConfig.MYSQL_LEX_CONFIG);
SqlNode sqlNode = sqlParser.parseStmt();

parseSql(sqlNode, sideTableSet, queueInfo, null, null);
parseSql(sqlNode, sideTableSet, queueInfo, null, null, null);
queueInfo.offer(sqlNode);
return queueInfo;
}
Expand All @@ -100,30 +100,36 @@ public Queue<Object> getExeQueue(String exeSql, Set<String> sideTableSet) throws
* @param parentSelectList
* @return
*/
public Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object> queueInfo, SqlNode parentWhere, SqlNodeList parentSelectList){
public Object parseSql(SqlNode sqlNode,
Set<String> sideTableSet,
Queue<Object> queueInfo,
SqlNode parentWhere,
SqlNodeList parentSelectList,
SqlNodeList parentGroupByList){
SqlKind sqlKind = sqlNode.getKind();
switch (sqlKind){
case WITH: {
SqlWith sqlWith = (SqlWith) sqlNode;
SqlNodeList sqlNodeList = sqlWith.withList;
for (SqlNode withAsTable : sqlNodeList) {
SqlWithItem sqlWithItem = (SqlWithItem) withAsTable;
parseSql(sqlWithItem.query, sideTableSet, queueInfo, parentWhere, parentSelectList);
parseSql(sqlWithItem.query, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
queueInfo.add(sqlWithItem);
}
parseSql(sqlWith.body, sideTableSet, queueInfo, parentWhere, parentSelectList);
parseSql(sqlWith.body, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
break;
}
case INSERT:
SqlNode sqlSource = ((SqlInsert)sqlNode).getSource();
return parseSql(sqlSource, sideTableSet, queueInfo, parentWhere, parentSelectList);
return parseSql(sqlSource, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
case SELECT:
SqlNode sqlFrom = ((SqlSelect)sqlNode).getFrom();
SqlNode sqlWhere = ((SqlSelect)sqlNode).getWhere();
SqlNodeList selectList = ((SqlSelect)sqlNode).getSelectList();
SqlNodeList groupByList = ((SqlSelect) sqlNode).getGroup();

if(sqlFrom.getKind() != IDENTIFIER){
Object result = parseSql(sqlFrom, sideTableSet, queueInfo, sqlWhere, selectList);
Object result = parseSql(sqlFrom, sideTableSet, queueInfo, sqlWhere, selectList, groupByList);
if(result instanceof JoinInfo){
return TableUtils.dealSelectResultWithJoinInfo((JoinInfo) result, (SqlSelect) sqlNode, queueInfo);
}else if(result instanceof AliasInfo){
Expand All @@ -145,7 +151,7 @@ public Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
Map<String, String> tableRef = Maps.newHashMap();
Map<String, String> fieldRef = Maps.newHashMap();
return joinNodeDealer.dealJoinNode((SqlJoin) sqlNode, sideTableSet, queueInfo,
parentWhere, parentSelectList, joinFieldSet, tableRef, fieldRef);
parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef);
case AS:
SqlNode info = ((SqlBasicCall)sqlNode).getOperands()[0];
SqlNode alias = ((SqlBasicCall) sqlNode).getOperands()[1];
Expand All @@ -154,7 +160,7 @@ public Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
if(info.getKind() == IDENTIFIER){
infoStr = info.toString();
} else {
infoStr = parseSql(info, sideTableSet, queueInfo, parentWhere, parentSelectList).toString();
infoStr = parseSql(info, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList).toString();
}

AliasInfo aliasInfo = new AliasInfo();
Expand All @@ -167,12 +173,12 @@ public Object parseSql(SqlNode sqlNode, Set<String> sideTableSet, Queue<Object>
SqlNode unionLeft = ((SqlBasicCall)sqlNode).getOperands()[0];
SqlNode unionRight = ((SqlBasicCall)sqlNode).getOperands()[1];

parseSql(unionLeft, sideTableSet, queueInfo, parentWhere, parentSelectList);
parseSql(unionRight, sideTableSet, queueInfo, parentWhere, parentSelectList);
parseSql(unionLeft, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
parseSql(unionRight, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
break;
case ORDER_BY:
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode;
parseSql(sqlOrderBy.query, sideTableSet, queueInfo, parentWhere, parentSelectList);
parseSql(sqlOrderBy.query, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList);
}
return "";
}
Expand Down
Loading

0 comments on commit a7967f7

Please sign in to comment.