diff --git a/core/src/main/java/com/dtstack/flink/sql/config/CalciteConfig.java b/core/src/main/java/com/dtstack/flink/sql/config/CalciteConfig.java deleted file mode 100644 index 54ae66bbc..000000000 --- a/core/src/main/java/com/dtstack/flink/sql/config/CalciteConfig.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -package com.dtstack.flink.sql.config; - -import org.apache.calcite.config.Lex; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.SqlParser.Config; - -public class CalciteConfig { - - public static Config MYSQL_LEX_CONFIG = SqlParser - .configBuilder() - .setLex(Lex.MYSQL) - .build(); - - - -} diff --git a/core/src/main/java/com/dtstack/flink/sql/parser/CreateTmpTableParser.java b/core/src/main/java/com/dtstack/flink/sql/parser/CreateTmpTableParser.java index de7141eb5..114dbd50b 100644 --- a/core/src/main/java/com/dtstack/flink/sql/parser/CreateTmpTableParser.java +++ b/core/src/main/java/com/dtstack/flink/sql/parser/CreateTmpTableParser.java @@ -21,11 +21,10 @@ package com.dtstack.flink.sql.parser; import com.dtstack.flink.sql.util.DtStringUtil; -import org.apache.calcite.config.Lex; import org.apache.calcite.sql.*; -import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; import com.google.common.collect.Lists; +import org.apache.flink.table.calcite.FlinkPlannerImpl; + import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -71,17 +70,12 @@ public void parseSql(String sql, SqlTree sqlTree) { tableName = matcher.group(1); selectSql = "select " + matcher.group(2); } - - SqlParser.Config config = SqlParser - .configBuilder() - .setLex(Lex.MYSQL) - .build(); - SqlParser sqlParser = SqlParser.create(selectSql,config); + FlinkPlannerImpl flinkPlanner = FlinkPlanner.getFlinkPlanner(); SqlNode sqlNode = null; try { - sqlNode = sqlParser.parseStmt(); - } catch (SqlParseException e) { + sqlNode = flinkPlanner.parse(selectSql); + } catch (Exception e) { throw new RuntimeException("", e); } diff --git a/core/src/main/java/com/dtstack/flink/sql/parser/FlinkPlanner.java b/core/src/main/java/com/dtstack/flink/sql/parser/FlinkPlanner.java new file mode 100644 index 000000000..7c76ec2cd --- /dev/null +++ 
b/core/src/main/java/com/dtstack/flink/sql/parser/FlinkPlanner.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.dtstack.flink.sql.parser;
+
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.tools.FrameworkConfig;
+import org.apache.flink.table.calcite.FlinkPlannerImpl;
+import org.apache.flink.table.calcite.FlinkTypeFactory;
+
+/**
+ * Holder for the single {@link FlinkPlannerImpl} shared by the SQL parsers.
+ * {@link #createFlinkPlanner} must run before {@link #getFlinkPlanner} is called.
+ *
+ * Date: 2020/3/31
+ * Company: www.dtstack.com
+ * @author maqi
+ */
+public class FlinkPlanner {
+
+    /** Shared planner instance; stays null until {@link #createFlinkPlanner} has been called. */
+    public static volatile FlinkPlannerImpl flinkPlanner;
+
+    private FlinkPlanner() {
+    }
+
+    /**
+     * Creates the shared planner on the first call (double-checked locking on the
+     * volatile field) and returns it. Later calls return the existing instance and
+     * ignore their arguments.
+     *
+     * @param frameworkConfig passed to the {@link FlinkPlannerImpl} constructor
+     * @param relOptPlanner passed to the {@link FlinkPlannerImpl} constructor
+     * @param typeFactory passed to the {@link FlinkPlannerImpl} constructor
+     * @return the shared planner
+     */
+    public static FlinkPlannerImpl createFlinkPlanner(FrameworkConfig frameworkConfig, RelOptPlanner relOptPlanner, FlinkTypeFactory typeFactory) {
+        if (flinkPlanner == null) {
+            synchronized (FlinkPlanner.class) {
+                if (flinkPlanner == null) {
+                    flinkPlanner = new FlinkPlannerImpl(frameworkConfig, relOptPlanner, typeFactory);
+                }
+            }
+        }
+        return flinkPlanner;
+    }
+
+    /**
+     * Returns the shared planner.
+     *
+     * @return the planner created by {@link #createFlinkPlanner}
+     * @throws IllegalStateException if no planner has been created yet
+     */
+    public static FlinkPlannerImpl getFlinkPlanner() {
+        // Read the volatile field once so the null check and the return see the same value.
+        FlinkPlannerImpl planner = flinkPlanner;
+        if (planner == null) {
+            throw new IllegalStateException(
+                    "FlinkPlanner is not initialized; call createFlinkPlanner(...) first");
+        }
+        return planner;
+    }
+}
diff --git a/core/src/main/java/com/dtstack/flink/sql/parser/InsertSqlParser.java b/core/src/main/java/com/dtstack/flink/sql/parser/InsertSqlParser.java
index a7c6db9eb..e2940c4f7 100644
--- 
a/core/src/main/java/com/dtstack/flink/sql/parser/InsertSqlParser.java +++ b/core/src/main/java/com/dtstack/flink/sql/parser/InsertSqlParser.java @@ -57,6 +57,7 @@ public void parseSql(String sql, SqlTree sqlTree) { .configBuilder() .setLex(Lex.MYSQL) .build(); + SqlParser sqlParser = SqlParser.create(sql,config); SqlNode sqlNode = null; try { diff --git a/core/src/main/java/com/dtstack/flink/sql/side/FieldInfo.java b/core/src/main/java/com/dtstack/flink/sql/side/FieldInfo.java index 85bad8c2c..1259ddecf 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/FieldInfo.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/FieldInfo.java @@ -85,4 +85,13 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(table, fieldName); } + + @Override + public String toString() { + return "FieldInfo{" + + "table='" + table + '\'' + + ", fieldName='" + fieldName + '\'' + + ", typeInformation=" + typeInformation + + '}'; + } } diff --git a/core/src/main/java/com/dtstack/flink/sql/side/FieldReplaceInfo.java b/core/src/main/java/com/dtstack/flink/sql/side/FieldReplaceInfo.java index 703721ef2..dfab231ca 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/FieldReplaceInfo.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/FieldReplaceInfo.java @@ -23,6 +23,7 @@ import com.google.common.collect.HashBasedTable; import org.apache.commons.lang3.StringUtils; + /** * 用于记录转换之后的表和原来表直接字段的关联关系 * Date: 2018/8/30 @@ -78,7 +79,7 @@ public void setTargetTableAlias(String targetTableAlias) { * @param fieldName * @return */ - public String getTargetFieldName(String tableName, String fieldName){ + public String getTargetFieldName(String tableName, String fieldName) { String targetFieldName = mappingTable.get(tableName, fieldName); if(StringUtils.isNotBlank(targetFieldName)){ return targetFieldName; diff --git a/core/src/main/java/com/dtstack/flink/sql/side/JoinInfo.java b/core/src/main/java/com/dtstack/flink/sql/side/JoinInfo.java index 
8a8fe21f6..8854ff4ec 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/JoinInfo.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/JoinInfo.java @@ -20,6 +20,8 @@ package com.dtstack.flink.sql.side; +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Maps; import org.apache.calcite.sql.JoinType; import org.apache.calcite.sql.SqlNode; import com.google.common.base.Strings; @@ -31,7 +33,6 @@ * Join信息 * Date: 2018/7/24 * Company: www.dtstack.com - * * @author xuchao */ @@ -40,9 +41,7 @@ public class JoinInfo implements Serializable { private static final long serialVersionUID = -1L; //左表是否是维表 - private boolean leftIsSideTable; - - private boolean leftIsTmpTable = false; + private boolean leftIsSideTable = false; //右表是否是维表 private boolean rightIsSideTable; @@ -67,6 +66,16 @@ public class JoinInfo implements Serializable { private JoinType joinType; + /** + * 左表需要查询的字段信息和output的时候对应的列名称 + */ + private Map leftSelectFieldInfo = Maps.newHashMap(); + + /** + * 右表需要查询的字段信息和output的时候对应的列名称 + */ + private Map rightSelectFieldInfo = Maps.newHashMap(); + public String getSideTableName(){ if(leftIsSideTable){ return leftTableAlias; @@ -195,19 +204,39 @@ public void setJoinType(JoinType joinType) { this.joinType = joinType; } - public boolean isLeftIsTmpTable() { - return leftIsTmpTable; + public Map getLeftSelectFieldInfo() { + return leftSelectFieldInfo; + } + + public void setLeftSelectFieldInfo(Map leftSelectFieldInfo) { + this.leftSelectFieldInfo = leftSelectFieldInfo; } - public void setLeftIsTmpTable(boolean leftIsTmpTable) { - this.leftIsTmpTable = leftIsTmpTable; + public Map getRightSelectFieldInfo() { + return rightSelectFieldInfo; + } + + public void setRightSelectFieldInfo(Map rightSelectFieldInfo) { + this.rightSelectFieldInfo = rightSelectFieldInfo; + } + + public HashBasedTable getTableFieldRef(){ + HashBasedTable mappingTable = HashBasedTable.create(); + getLeftSelectFieldInfo().forEach((key, value) -> { + 
mappingTable.put(getLeftTableAlias(), key, value); + }); + + getRightSelectFieldInfo().forEach((key, value) -> { + mappingTable.put(getRightTableAlias(), key, value); + }); + + return mappingTable; } @Override public String toString() { return "JoinInfo{" + "leftIsSideTable=" + leftIsSideTable + - ", leftIsTmpTable=" + leftIsTmpTable + ", rightIsSideTable=" + rightIsSideTable + ", leftTableName='" + leftTableName + '\'' + ", leftTableAlias='" + leftTableAlias + '\'' + diff --git a/core/src/main/java/com/dtstack/flink/sql/side/JoinNodeDealer.java b/core/src/main/java/com/dtstack/flink/sql/side/JoinNodeDealer.java index 40ebc22c4..f072e2591 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/JoinNodeDealer.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/JoinNodeDealer.java @@ -19,11 +19,11 @@ package com.dtstack.flink.sql.side; -import com.dtstack.flink.sql.config.CalciteConfig; +import com.dtstack.flink.sql.parser.FlinkPlanner; +import com.dtstack.flink.sql.util.ParseUtils; import com.dtstack.flink.sql.util.TableUtils; import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; +import com.google.common.collect.*; import org.apache.calcite.sql.JoinType; import org.apache.calcite.sql.SqlAsOperator; import org.apache.calcite.sql.SqlBasicCall; @@ -38,11 +38,11 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlCase; import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.calcite.FlinkPlannerImpl; import java.util.List; import java.util.Map; @@ -74,17 +74,23 @@ public JoinNodeDealer(SideSQLParser sideSQLParser){ * 解析 join 操作 * @param joinNode * @param sideTableSet 标明哪些表名是维表 - * @param 
queueInfo + * @param queueInfo sql执行队列 * @param parentWhere join 关联的最上层的where 节点 * @param parentSelectList join 关联的最上层的select 节点 * @param joinFieldSet * @param tableRef 存储构建临时表查询后源表和新表之间的关联关系 * @return */ - public JoinInfo dealJoinNode(SqlJoin joinNode, Set sideTableSet, - Queue queueInfo, SqlNode parentWhere, - SqlNodeList parentSelectList, Set> joinFieldSet, - Map tableRef) { + public JoinInfo dealJoinNode(SqlJoin joinNode, + Set sideTableSet, + Queue queueInfo, + SqlNode parentWhere, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + Set> joinFieldSet, + Map tableRef, + Map fieldRef) { + SqlNode leftNode = joinNode.getLeft(); SqlNode rightNode = joinNode.getRight(); JoinType joinType = joinNode.getJoinType(); @@ -93,66 +99,30 @@ public JoinInfo dealJoinNode(SqlJoin joinNode, Set sideTableSet, String leftTbAlias = ""; String rightTableName = ""; String rightTableAlias = ""; - boolean leftTbisTmp = false; - //如果是连续join 判断是否已经处理过添加到执行队列 - Boolean needBuildTemp = false; + //抽取join中的的条件 extractJoinField(joinNode.getCondition(), joinFieldSet); - if(leftNode.getKind() == IDENTIFIER){ - leftTbName = leftNode.toString(); - } else if (leftNode.getKind() == JOIN) { + if (leftNode.getKind() == JOIN) { //处理连续join - Tuple2 nestJoinResult = dealNestJoin((SqlJoin) leftNode, sideTableSet, - queueInfo, parentWhere, parentSelectList, joinFieldSet, tableRef); - needBuildTemp = nestJoinResult.f0; - SqlBasicCall buildAs = TableUtils.buildAsNodeByJoinInfo(nestJoinResult.f1, null, null); - - if(needBuildTemp){ - //记录表之间的关联关系 - String newLeftTableName = buildAs.getOperands()[1].toString(); - Set fromTableNameSet = Sets.newHashSet(); - TableUtils.getFromTableInfo(joinNode.getLeft(), fromTableNameSet); - for(String tbTmp : fromTableNameSet){ - tableRef.put(tbTmp, newLeftTableName); - } - - //替换leftNode 为新的查询 - joinNode.setLeft(buildAs); - leftNode = buildAs; - - //替换select field 中的对应字段 - for(SqlNode sqlNode : parentSelectList.getList()){ - for(String tbTmp : 
fromTableNameSet) { - TableUtils.replaceSelectFieldTable(sqlNode, tbTmp, newLeftTableName); - } - } - - //替换where 中的条件相关 - for(String tbTmp : fromTableNameSet){ - TableUtils.replaceWhereCondition(parentWhere, tbTmp, newLeftTableName); - } - - leftTbisTmp = true; - - } - - leftTbName = buildAs.getOperands()[0].toString(); - leftTbAlias = buildAs.getOperands()[1].toString(); + dealNestJoin(joinNode, sideTableSet, + queueInfo, parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef); + leftNode = joinNode.getLeft(); + } - } else if (leftNode.getKind() == AS) { - AliasInfo aliasInfo = (AliasInfo) sideSQLParser.parseSql(leftNode, sideTableSet, queueInfo, parentWhere, parentSelectList); + if (leftNode.getKind() == AS) { + AliasInfo aliasInfo = (AliasInfo) sideSQLParser.parseSql(leftNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); leftTbName = aliasInfo.getName(); leftTbAlias = aliasInfo.getAlias(); - - } else { - throw new RuntimeException(String.format("---not deal node with type %s", leftNode.getKind().toString())); + } else if(leftNode.getKind() == IDENTIFIER){ + leftTbName = leftNode.toString(); + leftTbAlias = leftTbName; } boolean leftIsSide = checkIsSideTable(leftTbName, sideTableSet); Preconditions.checkState(!leftIsSide, "side-table must be at the right of join operator"); - Tuple2 rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList); + Tuple2 rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); rightTableName = rightTableNameAndAlias.f0; rightTableAlias = rightTableNameAndAlias.f1; @@ -161,78 +131,264 @@ public JoinInfo dealJoinNode(SqlJoin joinNode, Set sideTableSet, throw new RuntimeException("side join not support join type of right[current support inner join and left join]"); } - if(leftNode.getKind() == JOIN && rightIsSide){ - needBuildTemp = true; - } - JoinInfo 
tableInfo = new JoinInfo(); tableInfo.setLeftTableName(leftTbName); tableInfo.setRightTableName(rightTableName); - if (StringUtils.isEmpty(leftTbAlias)){ - tableInfo.setLeftTableAlias(leftTbName); - } else { - tableInfo.setLeftTableAlias(leftTbAlias); - } - if (StringUtils.isEmpty(rightTableAlias)){ - tableInfo.setRightTableAlias(rightTableName); - } else { - tableInfo.setRightTableAlias(rightTableAlias); - } + leftTbAlias = StringUtils.isEmpty(leftTbAlias) ? leftTbName : leftTbAlias; + rightTableAlias = StringUtils.isEmpty(rightTableAlias) ? rightTableName : rightTableAlias; - TableUtils.replaceJoinFieldRefTableName(joinNode.getCondition(), tableRef); - - tableInfo.setLeftIsTmpTable(leftTbisTmp); - tableInfo.setLeftIsSideTable(leftIsSide); + tableInfo.setLeftTableAlias(leftTbAlias); + tableInfo.setRightTableAlias(rightTableAlias); tableInfo.setRightIsSideTable(rightIsSide); tableInfo.setLeftNode(leftNode); tableInfo.setRightNode(rightNode); tableInfo.setJoinType(joinType); tableInfo.setCondition(joinNode.getCondition()); + TableUtils.replaceJoinFieldRefTableName(joinNode.getCondition(), fieldRef); - if(tableInfo.getLeftNode().getKind() != AS && needBuildTemp){ - extractTemporaryQuery(tableInfo.getLeftNode(), tableInfo.getLeftTableAlias(), (SqlBasicCall) parentWhere, - parentSelectList, queueInfo, joinFieldSet, tableRef); - }else { - SqlKind asNodeFirstKind = ((SqlBasicCall)tableInfo.getLeftNode()).operands[0].getKind(); - if(asNodeFirstKind == SELECT){ - queueInfo.offer(tableInfo.getLeftNode()); - tableInfo.setLeftNode(((SqlBasicCall)tableInfo.getLeftNode()).operands[1]); - } + //extract 需要查询的字段信息 + if(rightIsSide){ + extractJoinNeedSelectField(leftNode, rightNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet, fieldRef, tableInfo); + } + + if(tableInfo.getLeftNode().getKind() != AS){ + return tableInfo; + } + + SqlKind asNodeFirstKind = ((SqlBasicCall)tableInfo.getLeftNode()).operands[0].getKind(); + if(asNodeFirstKind == SELECT){ + 
queueInfo.offer(tableInfo.getLeftNode()); + tableInfo.setLeftNode(((SqlBasicCall)tableInfo.getLeftNode()).operands[1]); } + return tableInfo; } + /** + * 获取join 之后需要查询的字段信息 + */ + public void extractJoinNeedSelectField(SqlNode leftNode, + SqlNode rightNode, + SqlNode parentWhere, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + Map tableRef, + Set> joinFieldSet, + Map fieldRef, + JoinInfo tableInfo){ + + Set extractSelectField = extractField(leftNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet); + Set rightExtractSelectField = extractField(rightNode, parentWhere, parentSelectList, parentGroupByList, tableRef, joinFieldSet); + + //重命名right 中和 left 重名的 + Map leftTbSelectField = Maps.newHashMap(); + Map rightTbSelectField = Maps.newHashMap(); + String newTableName = tableInfo.getNewTableAlias(); + + for(String tmpField : extractSelectField){ + String[] tmpFieldSplit = StringUtils.split(tmpField, '.'); + leftTbSelectField.put(tmpFieldSplit[1], tmpFieldSplit[1]); + fieldRef.put(tmpField, TableUtils.buildTableField(newTableName, tmpFieldSplit[1])); + } + + for(String tmpField : rightExtractSelectField){ + String[] tmpFieldSplit = StringUtils.split(tmpField, '.'); + String originalFieldName = tmpFieldSplit[1]; + String targetFieldName = originalFieldName; + if(leftTbSelectField.containsKey(originalFieldName)){ + targetFieldName = ParseUtils.dealDuplicateFieldName(leftTbSelectField, originalFieldName); + } + + rightTbSelectField.put(originalFieldName, targetFieldName); + fieldRef.put(tmpField, TableUtils.buildTableField(newTableName, targetFieldName)); + } + + tableInfo.setLeftSelectFieldInfo(leftTbSelectField); + tableInfo.setRightSelectFieldInfo(rightTbSelectField); + } + + /** + * 指定的节点关联到的 select 中的字段和 where中的字段 + * @param sqlNode + * @param parentWhere + * @param parentSelectList + * @param parentGroupByList + * @param tableRef + * @param joinFieldSet + * @return + */ + public Set extractField(SqlNode sqlNode, + SqlNode 
parentWhere, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + Map tableRef, + Set> joinFieldSet){ + Set fromTableNameSet = Sets.newHashSet(); + TableUtils.getFromTableInfo(sqlNode, fromTableNameSet); + Set extractCondition = Sets.newHashSet(); + + extractWhereCondition(fromTableNameSet, (SqlBasicCall) parentWhere, extractCondition); + Set extractSelectField = extractSelectFields(parentSelectList, fromTableNameSet, tableRef); + Set fieldFromJoinCondition = extractSelectFieldFromJoinCondition(joinFieldSet, fromTableNameSet, tableRef); + + Set extractGroupByField = extractFieldFromGroupByList(parentGroupByList, fromTableNameSet, tableRef); + + extractSelectField.addAll(extractCondition); + extractSelectField.addAll(fieldFromJoinCondition); + extractSelectField.addAll(extractGroupByField); + + return extractSelectField; + } - //处理多层join - private Tuple2 dealNestJoin(SqlJoin joinNode, Set sideTableSet, - Queue queueInfo, SqlNode parentWhere, - SqlNodeList selectList, Set> joinFieldSet, - Map tableRef){ - SqlNode rightNode = joinNode.getRight(); - Tuple2 rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, selectList); - JoinInfo joinInfo = dealJoinNode(joinNode, sideTableSet, queueInfo, parentWhere, selectList, joinFieldSet, tableRef); + + /** + * 处理多层join + * 判断左节点是否需要创建临时查询 + * (1)右节点是维表 + * (2)左节点不是 as 节点 + */ + private JoinInfo dealNestJoin(SqlJoin joinNode, + Set sideTableSet, + Queue queueInfo, + SqlNode parentWhere, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + Set> joinFieldSet, + Map tableRef, + Map fieldRef){ + + SqlJoin leftJoinNode = (SqlJoin) joinNode.getLeft(); + SqlNode parentRightJoinNode = joinNode.getRight(); + SqlNode rightNode = leftJoinNode.getRight(); + Tuple2 rightTableNameAndAlias = parseRightNode(rightNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); + Tuple2 parentRightJoinInfo = parseRightNode(parentRightJoinNode, sideTableSet, queueInfo, 
parentWhere, parentSelectList, parentGroupByList); + boolean parentRightIsSide = checkIsSideTable(parentRightJoinInfo.f0, sideTableSet); + + JoinInfo joinInfo = dealJoinNode(leftJoinNode, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef); String rightTableName = rightTableNameAndAlias.f0; boolean rightIsSide = checkIsSideTable(rightTableName, sideTableSet); - boolean needBuildTemp = false; + SqlBasicCall buildAs = TableUtils.buildAsNodeByJoinInfo(joinInfo, null, null); - if(!rightIsSide){ - //右表不是维表的情况 - }else{ - //右边表是维表需要重新构建左表的临时查询 - queueInfo.offer(joinInfo); - needBuildTemp = true; + if(rightIsSide){ + addSideInfoToExeQueue(queueInfo, joinInfo, joinNode, parentSelectList, parentGroupByList, parentWhere, tableRef); + } + + SqlNode newLeftNode = joinNode.getLeft(); + + if(newLeftNode.getKind() != AS && parentRightIsSide){ + + String leftTbAlias = buildAs.getOperands()[1].toString(); + extractTemporaryQuery(newLeftNode, leftTbAlias, (SqlBasicCall) parentWhere, + parentSelectList, queueInfo, joinFieldSet, tableRef, fieldRef); + + //替换leftNode 为新的查询 + joinNode.setLeft(buildAs); + replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentGroupByList, parentWhere); } - //return Tuple2.of(needBuildTemp, TableUtils.buildAsNodeByJoinInfo(joinInfo, null, null)); - return Tuple2.of(needBuildTemp, joinInfo); + return joinInfo; } - private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall parentWhere, - SqlNodeList parentSelectList, Queue queueInfo, + /** + * 右边表是维表需要重新构建左表的临时查询 + * 并将joinInfo 添加到执行队列里面 + * @param queueInfo + * @param joinInfo + * @param joinNode + * @param parentSelectList + * @param parentGroupByList + * @param parentWhere + * @param tableRef + */ + public void addSideInfoToExeQueue(Queue queueInfo, + JoinInfo joinInfo, + SqlJoin joinNode, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + SqlNode parentWhere, + Map tableRef){ + 
//只处理维表 + if(!joinInfo.isRightIsSideTable()){ + return; + } + + SqlBasicCall buildAs = TableUtils.buildAsNodeByJoinInfo(joinInfo, null, null); + SqlNode leftJoinNode = joinNode.getLeft(); + queueInfo.offer(joinInfo); + //替换左表为新的表名称 + joinNode.setLeft(buildAs); + + replaceSelectAndWhereField(buildAs, leftJoinNode, tableRef, parentSelectList, parentGroupByList, parentWhere); + } + + /** + * 替换指定的查询和条件节点中的字段为新的字段 + * @param buildAs + * @param leftJoinNode + * @param tableRef + * @param parentSelectList + * @param parentGroupByList + * @param parentWhere + */ + public void replaceSelectAndWhereField(SqlBasicCall buildAs, + SqlNode leftJoinNode, + Map tableRef, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList, + SqlNode parentWhere){ + + String newLeftTableName = buildAs.getOperands()[1].toString(); + Set fromTableNameSet = Sets.newHashSet(); + TableUtils.getFromTableInfo(leftJoinNode, fromTableNameSet); + + for(String tbTmp : fromTableNameSet){ + tableRef.put(tbTmp, newLeftTableName); + } + + //替换select field 中的对应字段 + HashBiMap fieldReplaceRef = HashBiMap.create(); + for(SqlNode sqlNode : parentSelectList.getList()){ + for(String tbTmp : fromTableNameSet) { + TableUtils.replaceSelectFieldTable(sqlNode, tbTmp, newLeftTableName, fieldReplaceRef); + } + } + + //TODO 应该根据上面的查询字段的关联关系来替换 + //替换where 中的条件相关 + for(String tbTmp : fromTableNameSet){ + TableUtils.replaceWhereCondition(parentWhere, tbTmp, newLeftTableName, fieldReplaceRef); + } + + if(parentGroupByList != null){ + for(SqlNode sqlNode : parentGroupByList.getList()){ + for(String tbTmp : fromTableNameSet) { + TableUtils.replaceSelectFieldTable(sqlNode, tbTmp, newLeftTableName, fieldReplaceRef); + } + } + } + + } + + /** + * 抽取出中间查询表 + * @param node + * @param tableAlias + * @param parentWhere + * @param parentSelectList + * @param queueInfo + * @param joinFieldSet + * @param tableRef + * @return 源自段和新生成字段之间的映射关系 + */ + private void extractTemporaryQuery(SqlNode node, String tableAlias, + SqlBasicCall 
parentWhere, + SqlNodeList parentSelectList, + Queue queueInfo, Set> joinFieldSet, - Map tableRef){ + Map tableRef, + Map fieldRef){ try{ //父一级的where 条件中如果只和临时查询相关的条件都截取进来 Set fromTableNameSet = Sets.newHashSet(); @@ -246,8 +402,13 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall } Set extractSelectField = extractSelectFields(parentSelectList, fromTableNameSet, tableRef); - Set fieldFromJoinCondition = extractSelectFieldFromJoinCondition(joinFieldSet, fromTableNameSet); - String extractSelectFieldStr = buildSelectNode(extractSelectField, fieldFromJoinCondition); + Set fieldFromJoinCondition = extractSelectFieldFromJoinCondition(joinFieldSet, fromTableNameSet, tableRef); + Set newFields = buildSelectNode(extractSelectField, fieldFromJoinCondition); + String extractSelectFieldStr = StringUtils.join(newFields, ','); + + Map oldRefNewField = buildTmpTableFieldRefOriField(newFields, tableAlias); + fieldRef.putAll(oldRefNewField); + String extractConditionStr = buildCondition(extractCondition); String tmpSelectSql = String.format(SELECT_TEMP_SQL, @@ -255,14 +416,33 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall node.toString(), extractConditionStr); - SqlParser sqlParser = SqlParser.create(tmpSelectSql, CalciteConfig.MYSQL_LEX_CONFIG); - SqlNode sqlNode = sqlParser.parseStmt(); + FlinkPlannerImpl flinkPlanner = FlinkPlanner.getFlinkPlanner(); + SqlNode sqlNode = flinkPlanner.parse(tmpSelectSql); + SqlBasicCall sqlBasicCall = buildAsSqlNode(tableAlias, sqlNode); queueInfo.offer(sqlBasicCall); + //替换select中的表结构 + HashBiMap fieldReplaceRef = HashBiMap.create(); + for(SqlNode tmpSelect : parentSelectList.getList()){ + for(String tbTmp : fromTableNameSet) { + TableUtils.replaceSelectFieldTable(tmpSelect, tbTmp, tableAlias, fieldReplaceRef); + } + } + + //替换where 中的条件相关 + for(String tbTmp : fromTableNameSet){ + TableUtils.replaceWhereCondition(parentWhere, tbTmp, tableAlias, fieldReplaceRef); + } + + 
for(String tbTmp : fromTableNameSet){ + tableRef.put(tbTmp, tableAlias); + } + System.out.println("-------build temporary query-----------"); System.out.println(tmpSelectSql); System.out.println("---------------------------------------"); + }catch (Exception e){ e.printStackTrace(); throw new RuntimeException(e); @@ -271,7 +451,6 @@ private void extractTemporaryQuery(SqlNode node, String tableAlias, SqlBasicCall /** * 抽取上层需用使用到的字段 - * 由于where字段已经抽取到上一层了所以不用查询出来 * @param parentSelectList * @param fromTableNameSet * @return @@ -287,12 +466,36 @@ private Set extractSelectFields(SqlNodeList parentSelectList, return extractFieldList; } - private Set extractSelectFieldFromJoinCondition(Set> joinFieldSet, Set fromTableNameSet){ + private Set extractSelectFieldFromJoinCondition(Set> joinFieldSet, + Set fromTableNameSet, + Map tableRef){ Set extractFieldList = Sets.newHashSet(); for(Tuple2 field : joinFieldSet){ if(fromTableNameSet.contains(field.f0)){ extractFieldList.add(field.f0 + "." + field.f1); } + + if(tableRef.containsKey(field.f0)){ + if(fromTableNameSet.contains(tableRef.get(field.f0))){ + extractFieldList.add(tableRef.get(field.f0) + "." 
+ field.f1); + } + } + } + + return extractFieldList; + } + + private Set extractFieldFromGroupByList(SqlNodeList parentGroupByList, + Set fromTableNameSet, + Map tableRef){ + + if(parentGroupByList == null){ + return Sets.newHashSet(); + } + + Set extractFieldList = Sets.newHashSet(); + for(SqlNode selectNode : parentGroupByList.getList()){ + extractSelectField(selectNode, extractFieldList, fromTableNameSet, tableRef); } return extractFieldList; @@ -304,6 +507,10 @@ private Set extractSelectFieldFromJoinCondition(Set> joinFieldSet){ + if (null == condition || condition.getKind() == LITERAL) { + return; + } + SqlKind joinKind = condition.getKind(); if( joinKind == AND || joinKind == EQUALS ){ extractJoinField(((SqlBasicCall)condition).operands[0], joinFieldSet); @@ -409,12 +616,12 @@ private void extractSelectField(SqlNode selectNode, private Tuple2 parseRightNode(SqlNode sqlNode, Set sideTableSet, Queue queueInfo, - SqlNode parentWhere, SqlNodeList selectList) { + SqlNode parentWhere, SqlNodeList selectList, SqlNodeList parentGroupByList) { Tuple2 tabName = new Tuple2<>("", ""); if(sqlNode.getKind() == IDENTIFIER){ tabName.f0 = sqlNode.toString(); }else{ - AliasInfo aliasInfo = (AliasInfo)sideSQLParser.parseSql(sqlNode, sideTableSet, queueInfo, parentWhere, selectList); + AliasInfo aliasInfo = (AliasInfo)sideSQLParser.parseSql(sqlNode, sideTableSet, queueInfo, parentWhere, selectList, parentGroupByList); tabName.f0 = aliasInfo.getName(); tabName.f1 = aliasInfo.getAlias(); } @@ -447,14 +654,39 @@ public String buildCondition(List conditionList){ return " where " + StringUtils.join(conditionList, " AND "); } - public String buildSelectNode(Set extractSelectField, Set joinFieldSet){ + /** + * 构建抽取表的查询字段信息 + * 包括去除重复字段,名称相同的取别名 + * @param extractSelectField + * @param joinFieldSet + * @return + */ + public Set buildSelectNode(Set extractSelectField, Set joinFieldSet){ if(CollectionUtils.isEmpty(extractSelectField)){ throw new RuntimeException("no field is used"); } - 
Sets.SetView view = Sets.union(extractSelectField, joinFieldSet); + Sets.SetView view = Sets.union(extractSelectField, joinFieldSet); + Set newFieldSet = Sets.newHashSet(); + //为相同的列取别名 + HashBiMap refFieldMap = HashBiMap.create(); + for(String field : view){ + String[] fieldInfo = StringUtils.split(field, '.'); + String aliasName = fieldInfo[1]; + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(field); + if(refFieldMap.inverse().get(aliasName) != null){ + aliasName = ParseUtils.dealDuplicateFieldName(refFieldMap, aliasName); + stringBuilder.append(" as ") + .append(aliasName); + } + + refFieldMap.put(field, aliasName); + + newFieldSet.add(stringBuilder.toString()); + } - return StringUtils.join(view, ","); + return newFieldSet; } private boolean checkIsSideTable(String tableName, Set sideTableList){ @@ -474,6 +706,38 @@ private SqlBasicCall buildAsSqlNode(String internalTableName, SqlNode newSource) return new SqlBasicCall(operator, sqlNodes, sqlParserPos); } + /** + * 获取where中和指定表有关联的字段 + * @param fromTableNameSet + * @param parentWhere + * @param extractCondition + */ + private void extractWhereCondition(Set fromTableNameSet, SqlBasicCall parentWhere, Set extractCondition){ + + if(parentWhere == null){ + return; + } + + SqlKind kind = parentWhere.getKind(); + if(kind == AND){ + extractWhereCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[0], extractCondition); + extractWhereCondition(fromTableNameSet, (SqlBasicCall) parentWhere.getOperands()[1], extractCondition); + } else { + + Set fieldInfos = Sets.newHashSet(); + TableUtils.getConditionRefTable(parentWhere, fieldInfos); + fieldInfos.forEach(fieldInfo -> { + String[] splitInfo = StringUtils.split(fieldInfo, "."); + if(splitInfo.length == 2 && fromTableNameSet.contains(splitInfo[0])){ + extractCondition.add(fieldInfo); + } + }); + + } + + + } + /** * 检查关联的where 条件中的判断是否可以下移到新构建的子查询 @@ -504,8 +768,18 @@ private boolean checkAndRemoveWhereCondition(Set 
fromTableNameSet, return false; } else { + //条件表达式,如果该条件关联的表都是指定的表则移除 + Set fieldInfos = Sets.newHashSet(); + TableUtils.getConditionRefTable(parentWhere, fieldInfos); Set conditionRefTableNameSet = Sets.newHashSet(); - TableUtils.getConditionRefTable(parentWhere, conditionRefTableNameSet); + + fieldInfos.forEach(fieldInfo -> { + String[] splitInfo = StringUtils.split(fieldInfo, "."); + if(splitInfo.length == 2){ + conditionRefTableNameSet.add(splitInfo[0]); + } + }); + if(fromTableNameSet.containsAll(conditionRefTableNameSet)){ return true; @@ -574,6 +848,25 @@ private SqlIdentifier checkAndReplaceJoinCondition(SqlNode node, Map buildTmpTableFieldRefOriField(Set fieldSet, String newTableAliasName){ + Map refInfo = Maps.newConcurrentMap(); + for(String field : fieldSet){ + String[] fields = StringUtils.splitByWholeSeparator(field, "as"); + String oldKey = field; + String[] oldFieldInfo = StringUtils.splitByWholeSeparator(fields[0], "."); + String oldFieldName = oldFieldInfo.length == 2 ? oldFieldInfo[1] : oldFieldInfo[0]; + String newKey = fields.length == 2 ? newTableAliasName + "." + fields[1] : + newTableAliasName + "." 
+ oldFieldName; + refInfo.put(oldKey, newKey); + } + + return refInfo; + } } diff --git a/core/src/main/java/com/dtstack/flink/sql/side/SideInfo.java b/core/src/main/java/com/dtstack/flink/sql/side/SideInfo.java index df41e1663..029c86e25 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/SideInfo.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/SideInfo.java @@ -55,6 +55,8 @@ public abstract class SideInfo implements Serializable{ protected String sideSelectFields = ""; + protected Map sideSelectFieldsType = Maps.newHashMap(); + protected JoinType joinType; //key:Returns the value of the position, value: the ref field index​in the input table @@ -84,15 +86,17 @@ public void parseSelectFields(JoinInfo joinInfo){ String sideTableName = joinInfo.getSideTableName(); String nonSideTableName = joinInfo.getNonSideTable(); List fields = Lists.newArrayList(); + int sideTableFieldIndex = 0; - int sideIndex = 0; for( int i=0; i getSideFieldNameIndex() { public void setSideFieldNameIndex(Map sideFieldNameIndex) { this.sideFieldNameIndex = sideFieldNameIndex; } + + public Map getSideSelectFieldsType() { + return sideSelectFieldsType; + } + + public void setSideSelectFieldsType(Map sideSelectFieldsType) { + this.sideSelectFieldsType = sideSelectFieldsType; + } + + public String getSelectSideFieldType(int index){ + return sideSelectFieldsType.get(index); + } } diff --git a/core/src/main/java/com/dtstack/flink/sql/side/SidePredicatesParser.java b/core/src/main/java/com/dtstack/flink/sql/side/SidePredicatesParser.java index 50103a9f5..0902bf39f 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/SidePredicatesParser.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/SidePredicatesParser.java @@ -18,7 +18,7 @@ package com.dtstack.flink.sql.side; -import com.dtstack.flink.sql.config.CalciteConfig; +import com.dtstack.flink.sql.parser.FlinkPlanner; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import 
org.apache.calcite.sql.SqlBasicCall; @@ -30,11 +30,10 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; +import org.apache.flink.table.calcite.FlinkPlannerImpl; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import static org.apache.calcite.sql.SqlKind.*; @@ -47,8 +46,8 @@ */ public class SidePredicatesParser { public void fillPredicatesForSideTable(String exeSql, Map sideTableMap) throws SqlParseException { - SqlParser sqlParser = SqlParser.create(exeSql, CalciteConfig.MYSQL_LEX_CONFIG); - SqlNode sqlNode = sqlParser.parseStmt(); + FlinkPlannerImpl flinkPlanner = FlinkPlanner.getFlinkPlanner(); + SqlNode sqlNode = flinkPlanner.parse(exeSql); parseSql(sqlNode, sideTableMap, Maps.newHashMap()); } diff --git a/core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java b/core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java index 061fe52a2..f37c3f78c 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/SideSQLParser.java @@ -20,43 +20,29 @@ package com.dtstack.flink.sql.side; -import com.dtstack.flink.sql.config.CalciteConfig; +import com.dtstack.flink.sql.parser.FlinkPlanner; import com.dtstack.flink.sql.util.TableUtils; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Queues; import com.google.common.collect.Sets; -import org.apache.calcite.sql.JoinType; -import org.apache.calcite.sql.SqlAsOperator; import org.apache.calcite.sql.SqlBasicCall; -import org.apache.calcite.sql.SqlBinaryOperator; -import org.apache.calcite.sql.SqlDataTypeSpec; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlInsert; import org.apache.calcite.sql.SqlJoin; import 
org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOrderBy; import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.SqlWith; import org.apache.calcite.sql.SqlWithItem; -import org.apache.calcite.sql.fun.SqlCase; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.table.api.Table; +import org.apache.flink.table.calcite.FlinkPlannerImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Set; @@ -82,57 +68,30 @@ public Queue getExeQueue(String exeSql, Set sideTableSet) throws LOG.info(exeSql); Queue queueInfo = Queues.newLinkedBlockingQueue(); - SqlParser sqlParser = SqlParser.create(exeSql, CalciteConfig.MYSQL_LEX_CONFIG); - SqlNode sqlNode = sqlParser.parseStmt(); + FlinkPlannerImpl flinkPlanner = FlinkPlanner.getFlinkPlanner(); + SqlNode sqlNode = flinkPlanner.parse(exeSql); - parseSql(sqlNode, sideTableSet, queueInfo, null, null); + parseSql(sqlNode, sideTableSet, queueInfo, null, null, null); queueInfo.offer(sqlNode); return queueInfo; } - private void checkAndReplaceMultiJoin(SqlNode sqlNode, Set sideTableSet) { - SqlKind sqlKind = sqlNode.getKind(); - switch (sqlKind) { - case WITH: { - SqlWith sqlWith = (SqlWith) sqlNode; - SqlNodeList sqlNodeList = sqlWith.withList; - for (SqlNode withAsTable : sqlNodeList) { - SqlWithItem sqlWithItem = (SqlWithItem) withAsTable; - checkAndReplaceMultiJoin(sqlWithItem.query, sideTableSet); - } - 
checkAndReplaceMultiJoin(sqlWith.body, sideTableSet); - break; - } - case INSERT: - SqlNode sqlSource = ((SqlInsert) sqlNode).getSource(); - checkAndReplaceMultiJoin(sqlSource, sideTableSet); - break; - case SELECT: - SqlNode sqlFrom = ((SqlSelect) sqlNode).getFrom(); - if (sqlFrom.getKind() != IDENTIFIER) { - checkAndReplaceMultiJoin(sqlFrom, sideTableSet); - } - break; - case JOIN: - convertSideJoinToNewQuery((SqlJoin) sqlNode, sideTableSet); - break; - case AS: - SqlNode info = ((SqlBasicCall) sqlNode).getOperands()[0]; - if (info.getKind() != IDENTIFIER) { - checkAndReplaceMultiJoin(info, sideTableSet); - } - break; - case UNION: - SqlNode unionLeft = ((SqlBasicCall) sqlNode).getOperands()[0]; - SqlNode unionRight = ((SqlBasicCall) sqlNode).getOperands()[1]; - checkAndReplaceMultiJoin(unionLeft, sideTableSet); - checkAndReplaceMultiJoin(unionRight, sideTableSet); - break; - } - } - - public Object parseSql(SqlNode sqlNode, Set sideTableSet, Queue queueInfo, SqlNode parentWhere, SqlNodeList parentSelectList){ + /** + * 解析 sql 根据维表 join关系重新组装新的sql + * @param sqlNode + * @param sideTableSet + * @param queueInfo + * @param parentWhere + * @param parentSelectList + * @return + */ + public Object parseSql(SqlNode sqlNode, + Set sideTableSet, + Queue queueInfo, + SqlNode parentWhere, + SqlNodeList parentSelectList, + SqlNodeList parentGroupByList){ SqlKind sqlKind = sqlNode.getKind(); switch (sqlKind){ case WITH: { @@ -140,22 +99,23 @@ public Object parseSql(SqlNode sqlNode, Set sideTableSet, Queue SqlNodeList sqlNodeList = sqlWith.withList; for (SqlNode withAsTable : sqlNodeList) { SqlWithItem sqlWithItem = (SqlWithItem) withAsTable; - parseSql(sqlWithItem.query, sideTableSet, queueInfo, parentWhere, parentSelectList); + parseSql(sqlWithItem.query, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); queueInfo.add(sqlWithItem); } - parseSql(sqlWith.body, sideTableSet, queueInfo, parentWhere, parentSelectList); + parseSql(sqlWith.body, 
sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); break; } case INSERT: SqlNode sqlSource = ((SqlInsert)sqlNode).getSource(); - return parseSql(sqlSource, sideTableSet, queueInfo, parentWhere, parentSelectList); + return parseSql(sqlSource, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); case SELECT: SqlNode sqlFrom = ((SqlSelect)sqlNode).getFrom(); SqlNode sqlWhere = ((SqlSelect)sqlNode).getWhere(); SqlNodeList selectList = ((SqlSelect)sqlNode).getSelectList(); + SqlNodeList groupByList = ((SqlSelect) sqlNode).getGroup(); if(sqlFrom.getKind() != IDENTIFIER){ - Object result = parseSql(sqlFrom, sideTableSet, queueInfo, sqlWhere, selectList); + Object result = parseSql(sqlFrom, sideTableSet, queueInfo, sqlWhere, selectList, groupByList); if(result instanceof JoinInfo){ return TableUtils.dealSelectResultWithJoinInfo((JoinInfo) result, (SqlSelect) sqlNode, queueInfo); }else if(result instanceof AliasInfo){ @@ -175,7 +135,9 @@ public Object parseSql(SqlNode sqlNode, Set sideTableSet, Queue JoinNodeDealer joinNodeDealer = new JoinNodeDealer(this); Set> joinFieldSet = Sets.newHashSet(); Map tableRef = Maps.newHashMap(); - return joinNodeDealer.dealJoinNode((SqlJoin) sqlNode, sideTableSet, queueInfo, parentWhere, parentSelectList, joinFieldSet, tableRef); + Map fieldRef = Maps.newHashMap(); + return joinNodeDealer.dealJoinNode((SqlJoin) sqlNode, sideTableSet, queueInfo, + parentWhere, parentSelectList, parentGroupByList, joinFieldSet, tableRef, fieldRef); case AS: SqlNode info = ((SqlBasicCall)sqlNode).getOperands()[0]; SqlNode alias = ((SqlBasicCall) sqlNode).getOperands()[1]; @@ -184,7 +146,7 @@ public Object parseSql(SqlNode sqlNode, Set sideTableSet, Queue if(info.getKind() == IDENTIFIER){ infoStr = info.toString(); } else { - infoStr = parseSql(info, sideTableSet, queueInfo, parentWhere, parentSelectList).toString(); + infoStr = parseSql(info, sideTableSet, queueInfo, parentWhere, parentSelectList, 
parentGroupByList).toString(); } AliasInfo aliasInfo = new AliasInfo(); @@ -197,36 +159,20 @@ public Object parseSql(SqlNode sqlNode, Set sideTableSet, Queue SqlNode unionLeft = ((SqlBasicCall)sqlNode).getOperands()[0]; SqlNode unionRight = ((SqlBasicCall)sqlNode).getOperands()[1]; - parseSql(unionLeft, sideTableSet, queueInfo, parentWhere, parentSelectList); - parseSql(unionRight, sideTableSet, queueInfo, parentWhere, parentSelectList); + parseSql(unionLeft, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); + parseSql(unionRight, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); break; case ORDER_BY: SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode; - parseSql(sqlOrderBy.query, sideTableSet, queueInfo, parentWhere, parentSelectList); + parseSql(sqlOrderBy.query, sideTableSet, queueInfo, parentWhere, parentSelectList, parentGroupByList); + + case LITERAL: + return LITERAL.toString(); } return ""; } - private AliasInfo getSqlNodeAliasInfo(SqlNode sqlNode) { - SqlNode info = ((SqlBasicCall) sqlNode).getOperands()[0]; - SqlNode alias = ((SqlBasicCall) sqlNode).getOperands()[1]; - String infoStr = info.getKind() == IDENTIFIER ? 
info.toString() : null; - AliasInfo aliasInfo = new AliasInfo(); - aliasInfo.setName(infoStr); - aliasInfo.setAlias(alias.toString()); - return aliasInfo; - } - - /** - * 将和维表关联的join 替换为一个新的查询 - * @param sqlNode - * @param sideTableSet - */ - private void convertSideJoinToNewQuery(SqlJoin sqlNode, Set sideTableSet) { - checkAndReplaceMultiJoin(sqlNode.getLeft(), sideTableSet); - checkAndReplaceMultiJoin(sqlNode.getRight(), sideTableSet); - } public void setLocalTableCache(Map localTableCache) { diff --git a/core/src/main/java/com/dtstack/flink/sql/side/SideSqlExec.java b/core/src/main/java/com/dtstack/flink/sql/side/SideSqlExec.java index f5e18345d..f90138b2a 100644 --- a/core/src/main/java/com/dtstack/flink/sql/side/SideSqlExec.java +++ b/core/src/main/java/com/dtstack/flink/sql/side/SideSqlExec.java @@ -43,10 +43,7 @@ import com.dtstack.flink.sql.util.ParseUtils; import com.dtstack.flink.sql.util.TableUtils; import com.google.common.base.Preconditions; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; +import com.google.common.collect.*; import org.apache.calcite.sql.SqlAsOperator; import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlDataTypeSpec; @@ -99,8 +96,12 @@ public class SideSqlExec { private Map localTableCache = Maps.newHashMap(); - public void exec(String sql, Map sideTableMap, StreamTableEnvironment tableEnv, - Map tableCache, StreamQueryConfig queryConfig, CreateTmpTableParser.SqlParserResult createView) throws Exception { + public void exec(String sql, + Map sideTableMap, + StreamTableEnvironment tableEnv, + Map tableCache, + StreamQueryConfig queryConfig, + CreateTmpTableParser.SqlParserResult createView) throws Exception { if(localSqlPluginPath == null){ throw new RuntimeException("need to set localSqlPluginPath"); } @@ -123,24 +124,11 @@ public void exec(String sql, Map sideTableMap, StreamTabl Queue 
exeQueue = sideSQLParser.getExeQueue(sql, sideTableMap.keySet()); Object pollObj = null; - //need clean - boolean preIsSideJoin = false; - List replaceInfoList = Lists.newArrayList(); - while((pollObj = exeQueue.poll()) != null){ if(pollObj instanceof SqlNode){ SqlNode pollSqlNode = (SqlNode) pollObj; - if(preIsSideJoin){ - preIsSideJoin = false; - List fieldNames = null; - for(FieldReplaceInfo replaceInfo : replaceInfoList){ - fieldNames = Lists.newArrayList(); - replaceFieldName(pollSqlNode, replaceInfo); - addAliasForFieldNode(pollSqlNode, fieldNames, replaceInfo.getMappingTable()); - } - } if(pollSqlNode.getKind() == INSERT){ System.out.println("----------real exec sql-----------" ); @@ -151,7 +139,7 @@ public void exec(String sql, Map sideTableMap, StreamTabl } }else if(pollSqlNode.getKind() == AS){ - dealAsSourceTable(tableEnv, pollSqlNode, tableCache, replaceInfoList); + dealAsSourceTable(tableEnv, pollSqlNode, tableCache); } else if (pollSqlNode.getKind() == WITH_ITEM) { SqlWithItem sqlWithItem = (SqlWithItem) pollSqlNode; @@ -180,8 +168,7 @@ public void exec(String sql, Map sideTableMap, StreamTabl }else if (pollObj instanceof JoinInfo){ System.out.println("----------exec join info----------"); System.out.println(pollObj.toString()); - preIsSideJoin = true; - joinFun(pollObj, localTableCache, sideTableMap, tableEnv, replaceInfoList); + joinFun(pollObj, localTableCache, sideTableMap, tableEnv); } } @@ -221,66 +208,6 @@ private FieldReplaceInfo parseAsQuery(SqlBasicCall asSqlNode, Map } - /** - * 添加字段别名 - * @param pollSqlNode - * @param fieldList - * @param mappingTable - */ - private void addAliasForFieldNode(SqlNode pollSqlNode, List fieldList, HashBasedTable mappingTable) { - SqlKind sqlKind = pollSqlNode.getKind(); - switch (sqlKind) { - case INSERT: - SqlNode source = ((SqlInsert) pollSqlNode).getSource(); - addAliasForFieldNode(source, fieldList, mappingTable); - break; - case AS: - addAliasForFieldNode(((SqlBasicCall) pollSqlNode).getOperands()[0], 
fieldList, mappingTable); - break; - case SELECT: - SqlNodeList selectList = ((SqlSelect) pollSqlNode).getSelectList(); - selectList.getList().forEach(node -> { - if (node.getKind() == IDENTIFIER) { - SqlIdentifier sqlIdentifier = (SqlIdentifier) node; - if (sqlIdentifier.names.size() == 1) { - return; - } - // save real field - String fieldName = sqlIdentifier.names.get(1); - if (!fieldName.endsWith("0") || fieldName.endsWith("0") && mappingTable.columnMap().containsKey(fieldName)) { - fieldList.add(fieldName); - } - - } - }); - for (int i = 0; i < selectList.getList().size(); i++) { - SqlNode node = selectList.get(i); - if (node.getKind() == IDENTIFIER) { - SqlIdentifier sqlIdentifier = (SqlIdentifier) node; - if (sqlIdentifier.names.size() == 1) { - return; - } - String name = sqlIdentifier.names.get(1); - // avoid real field pv0 convert pv - if (name.endsWith("0") && !fieldList.contains(name) && !fieldList.contains(name.substring(0, name.length() - 1))) { - SqlOperator operator = new SqlAsOperator(); - SqlParserPos sqlParserPos = new SqlParserPos(0, 0); - - SqlIdentifier sqlIdentifierAlias = new SqlIdentifier(name.substring(0, name.length() - 1), null, sqlParserPos); - SqlNode[] sqlNodes = new SqlNode[2]; - sqlNodes[0] = sqlIdentifier; - sqlNodes[1] = sqlIdentifierAlias; - SqlBasicCall sqlBasicCall = new SqlBasicCall(operator, sqlNodes, sqlParserPos); - - selectList.set(i, sqlBasicCall); - } - } - } - break; - } - } - - public AliasInfo parseASNode(SqlNode sqlNode) throws SqlParseException { SqlKind sqlKind = sqlNode.getKind(); if(sqlKind != AS){ @@ -297,15 +224,17 @@ public AliasInfo parseASNode(SqlNode sqlNode) throws SqlParseException { return aliasInfo; } - public RowTypeInfo buildOutRowTypeInfo(List sideJoinFieldInfo, HashBasedTable mappingTable) { + public RowTypeInfo buildOutRowTypeInfo(List sideJoinFieldInfo, + HashBasedTable mappingTable) { TypeInformation[] sideOutTypes = new TypeInformation[sideJoinFieldInfo.size()]; String[] sideOutNames = new 
String[sideJoinFieldInfo.size()]; for (int i = 0; i < sideJoinFieldInfo.size(); i++) { FieldInfo fieldInfo = sideJoinFieldInfo.get(i); String tableName = fieldInfo.getTable(); String fieldName = fieldInfo.getFieldName(); - String mappingFieldName = ParseUtils.dealDuplicateFieldName(mappingTable, fieldName); - mappingTable.put(tableName, fieldName, mappingFieldName); + + String mappingFieldName = mappingTable.get(tableName, fieldName); + Preconditions.checkNotNull(mappingFieldName, fieldInfo + " not mapping any field! it may be frame bug"); sideOutTypes[i] = fieldInfo.getTypeInformation(); sideOutNames[i] = mappingFieldName; @@ -337,184 +266,9 @@ private TypeInformation convertTimeAttributeType(TypeInformation typeInformation return typeInformation; } - //需要考虑更多的情况 - private void replaceFieldName(SqlNode sqlNode, FieldReplaceInfo replaceInfo) { - SqlKind sqlKind = sqlNode.getKind(); - switch (sqlKind) { - case INSERT: - SqlNode sqlSource = ((SqlInsert) sqlNode).getSource(); - replaceFieldName(sqlSource, replaceInfo); - break; - case AS: - SqlNode asNode = ((SqlBasicCall) sqlNode).getOperands()[0]; - replaceFieldName(asNode, replaceInfo); - break; - case SELECT: - SqlSelect sqlSelect = (SqlSelect) filterNodeWithTargetName(sqlNode, replaceInfo.getTargetTableName()); - if(sqlSelect == null){ - return; - } - SqlNode sqlSource1 = sqlSelect.getFrom(); - if(sqlSource1.getKind() == AS){ - String tableName = ((SqlBasicCall)sqlSource1).getOperands()[0].toString(); - if(tableName.equalsIgnoreCase(replaceInfo.getTargetTableName())){ - SqlNodeList sqlSelectList = sqlSelect.getSelectList(); - SqlNode whereNode = sqlSelect.getWhere(); - SqlNodeList sqlGroup = sqlSelect.getGroup(); - - //TODO 暂时不处理having - SqlNode sqlHaving = sqlSelect.getHaving(); - - List newSelectNodeList = Lists.newArrayList(); - for( int i=0; i replaceNodeList = replaceSelectStarFieldName(selectNode, replaceInfo); - newSelectNodeList.addAll(replaceNodeList); - continue; - } - - SqlNode replaceNode = 
replaceSelectFieldName(selectNode, replaceInfo); - if(replaceNode == null){ - continue; - } - - //sqlSelectList.set(i, replaceNode); - newSelectNodeList.add(replaceNode); - } - SqlNodeList newSelectList = new SqlNodeList(newSelectNodeList, sqlSelectList.getParserPosition()); - sqlSelect.setSelectList(newSelectList); - - //where - if(whereNode != null){ - SqlNode[] sqlNodeList = ((SqlBasicCall)whereNode).getOperands(); - for(int i =0; i localTableCache, String table } if(table == null){ - throw new RuntimeException("not register table " + tableName); + throw new RuntimeException("not register table " + tableAlias); } return table; } - private List replaceSelectStarFieldName(SqlNode selectNode, FieldReplaceInfo replaceInfo){ - SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; - List sqlNodes = Lists.newArrayList(); - if(sqlIdentifier.isStar()){//处理 [* or table.*] - int identifierSize = sqlIdentifier.names.size(); - Collection columns = null; - if(identifierSize == 1){ - columns = replaceInfo.getMappingTable().values(); - }else{ - columns = replaceInfo.getMappingTable().row(sqlIdentifier.names.get(0)).values(); - } - - for(String colAlias : columns){ - SqlParserPos sqlParserPos = new SqlParserPos(0, 0); - List columnInfo = Lists.newArrayList(); - columnInfo.add(replaceInfo.getTargetTableAlias()); - columnInfo.add(colAlias); - SqlIdentifier sqlIdentifierAlias = new SqlIdentifier(columnInfo, sqlParserPos); - sqlNodes.add(sqlIdentifierAlias); - } - - return sqlNodes; - }else{ - throw new RuntimeException("is not a star select field." 
+ selectNode); - } - } - - private SqlNode replaceSelectFieldName(SqlNode selectNode, FieldReplaceInfo replaceInfo) { - if (selectNode.getKind() == AS) { - SqlNode leftNode = ((SqlBasicCall) selectNode).getOperands()[0]; - SqlNode replaceNode = replaceSelectFieldName(leftNode, replaceInfo); - if (replaceNode != null) { - ((SqlBasicCall) selectNode).getOperands()[0] = replaceNode; - } - - return selectNode; - }else if(selectNode.getKind() == IDENTIFIER){ - SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; - - if(sqlIdentifier.names.size() == 1){ - return selectNode; - } - - //Same level mappingTable - String mappingFieldName = replaceInfo.getTargetFieldName(sqlIdentifier.getComponent(0).getSimple(), sqlIdentifier.getComponent(1).getSimple()); - if (mappingFieldName == null) { - throw new RuntimeException("can't find mapping fieldName:" + selectNode.toString() ); - } - - sqlIdentifier = sqlIdentifier.setName(0, replaceInfo.getTargetTableAlias()); - sqlIdentifier = sqlIdentifier.setName(1, mappingFieldName); - return sqlIdentifier; - }else if(selectNode.getKind() == LITERAL || selectNode.getKind() == LITERAL_CHAIN){//字面含义 - return selectNode; - }else if( AGGREGATE.contains(selectNode.getKind()) - || AVG_AGG_FUNCTIONS.contains(selectNode.getKind()) - || COMPARISON.contains(selectNode.getKind()) - || selectNode.getKind() == OTHER_FUNCTION - || selectNode.getKind() == DIVIDE - || selectNode.getKind() == CAST - || selectNode.getKind() == TRIM - || selectNode.getKind() == TIMES - || selectNode.getKind() == PLUS - || selectNode.getKind() == NOT_IN - || selectNode.getKind() == OR - || selectNode.getKind() == AND - || selectNode.getKind() == MINUS - || selectNode.getKind() == TUMBLE - || selectNode.getKind() == TUMBLE_START - || selectNode.getKind() == TUMBLE_END - || selectNode.getKind() == SESSION - || selectNode.getKind() == SESSION_START - || selectNode.getKind() == SESSION_END - || selectNode.getKind() == HOP - || selectNode.getKind() == HOP_START - || 
selectNode.getKind() == HOP_END - || selectNode.getKind() == BETWEEN - || selectNode.getKind() == IS_NULL - || selectNode.getKind() == IS_NOT_NULL - || selectNode.getKind() == CONTAINS - || selectNode.getKind() == TIMESTAMP_ADD - || selectNode.getKind() == TIMESTAMP_DIFF - || selectNode.getKind() == LIKE - - ){ - SqlBasicCall sqlBasicCall = (SqlBasicCall) selectNode; - for(int i=0; i getConditionFields(SqlNode conditionNode, String specifyTabl protected void dealAsSourceTable(StreamTableEnvironment tableEnv, SqlNode pollSqlNode, - Map tableCache, - List replaceInfoList) throws SqlParseException { + Map tableCache) throws SqlParseException { AliasInfo aliasInfo = parseASNode(pollSqlNode); if (localTableCache.containsKey(aliasInfo.getName())) { @@ -748,19 +364,13 @@ protected void dealAsSourceTable(StreamTableEnvironment tableEnv, Set fromTableNameSet = Sets.newHashSet(); SqlNode fromNode = ((SqlBasicCall)pollSqlNode).getOperands()[0]; TableUtils.getFromTableInfo(fromNode, fromTableNameSet); - for(FieldReplaceInfo tmp : replaceInfoList){ - if(fromTableNameSet.contains(tmp.getTargetTableName()) - || fromTableNameSet.contains(tmp.getTargetTableAlias())){ - fieldReplaceInfo.setPreNode(tmp); - break; - } - } - replaceInfoList.add(fieldReplaceInfo); + } - private void joinFun(Object pollObj, Map localTableCache, - Map sideTableMap, StreamTableEnvironment tableEnv, - List replaceInfoList) throws Exception{ + private void joinFun(Object pollObj, + Map localTableCache, + Map sideTableMap, + StreamTableEnvironment tableEnv) throws Exception{ JoinInfo joinInfo = (JoinInfo) pollObj; JoinScope joinScope = new JoinScope(); @@ -768,11 +378,6 @@ private void joinFun(Object pollObj, Map localTableCache, leftScopeChild.setAlias(joinInfo.getLeftTableAlias()); leftScopeChild.setTableName(joinInfo.getLeftTableName()); - SqlKind sqlKind = joinInfo.getLeftNode().getKind(); - if(sqlKind == AS){ - dealAsSourceTable(tableEnv, joinInfo.getLeftNode(), localTableCache, replaceInfoList); - } - 
Table leftTable = getTableFromCache(localTableCache, joinInfo.getLeftTableAlias(), joinInfo.getLeftTableName()); RowTypeInfo leftTypeInfo = new RowTypeInfo(leftTable.getSchema().getTypes(), leftTable.getSchema().getColumnNames()); leftScopeChild.setRowTypeInfo(leftTypeInfo); @@ -798,8 +403,12 @@ private void joinFun(Object pollObj, Map localTableCache, joinScope.addScope(leftScopeChild); joinScope.addScope(rightScopeChild); + HashBasedTable mappingTable = ((JoinInfo) pollObj).getTableFieldRef(); + //获取两个表的所有字段 List sideJoinFieldInfo = ParserJoinField.getRowTypeInfo(joinInfo.getSelectNode(), joinScope, true); + //通过join的查询字段信息过滤出需要的字段信息 + sideJoinFieldInfo.removeIf(tmpFieldInfo -> mappingTable.get(tmpFieldInfo.getTable(), tmpFieldInfo.getFieldName()) == null); String leftTableAlias = joinInfo.getLeftTableAlias(); Table targetTable = localTableCache.get(leftTableAlias); @@ -830,8 +439,6 @@ private void joinFun(Object pollObj, Map localTableCache, dsOut = SideAsyncOperator.getSideJoinDataStream(adaptStream, sideTableInfo.getType(), localSqlPluginPath, typeInfo, joinInfo, sideJoinFieldInfo, sideTableInfo); } - // TODO 将嵌套表中的字段传递过去, 去除冗余的ROWtime - HashBasedTable mappingTable = HashBasedTable.create(); RowTypeInfo sideOutTypeInfo = buildOutRowTypeInfo(sideJoinFieldInfo, mappingTable); CRowTypeInfo cRowTypeInfo = new CRowTypeInfo(sideOutTypeInfo); @@ -845,17 +452,6 @@ private void joinFun(Object pollObj, Map localTableCache, replaceInfo.setTargetTableName(targetTableName); replaceInfo.setTargetTableAlias(targetTableAlias); - //判断之前是不是被替换过,被替换过则设置之前的替换信息作为上一个节点 - for(FieldReplaceInfo tmp : replaceInfoList){ - if(tmp.getTargetTableName().equalsIgnoreCase(joinInfo.getLeftTableName()) - ||tmp.getTargetTableName().equalsIgnoreCase(joinInfo.getLeftTableAlias())){ - replaceInfo.setPreNode(tmp); - break; - } - } - - replaceInfoList.add(replaceInfo); - if (!tableEnv.isRegistered(joinInfo.getNewTableName())){ Table joinTable = tableEnv.fromDataStream(dsOut); 
tableEnv.registerTable(joinInfo.getNewTableName(), joinTable); diff --git a/core/src/main/java/com/dtstack/flink/sql/util/FieldReplaceUtil.java b/core/src/main/java/com/dtstack/flink/sql/util/FieldReplaceUtil.java new file mode 100644 index 000000000..10919ca5b --- /dev/null +++ b/core/src/main/java/com/dtstack/flink/sql/util/FieldReplaceUtil.java @@ -0,0 +1,339 @@ +package com.dtstack.flink.sql.util; + +import com.dtstack.flink.sql.side.FieldReplaceInfo; +import com.google.common.collect.Lists; +import org.apache.calcite.sql.*; +import org.apache.calcite.sql.fun.SqlCase; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.commons.collections.CollectionUtils; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.apache.calcite.sql.SqlKind.*; + +/** + * 替换 字段 + */ +public class FieldReplaceUtil { + + /** + * 需要考虑更多的情况 + */ + public static void replaceFieldName(SqlNode sqlNode, + String oldTbName, + String newTbName, + Map mappingField) { + SqlKind sqlKind = sqlNode.getKind(); + switch (sqlKind) { + case INSERT: + SqlNode sqlSource = ((SqlInsert) sqlNode).getSource(); + replaceFieldName(sqlSource, oldTbName, newTbName, mappingField); + break; + case AS: + SqlNode asNode = ((SqlBasicCall) sqlNode).getOperands()[0]; + replaceFieldName(asNode, oldTbName, newTbName, mappingField); + break; + case SELECT: + SqlSelect sqlSelect = (SqlSelect) sqlNode; + SqlNodeList sqlSelectList = sqlSelect.getSelectList(); + SqlNode whereNode = sqlSelect.getWhere(); + SqlNodeList sqlGroup = sqlSelect.getGroup(); + + //TODO 抽取,暂时使用使用单个join条件作为测试 + if(sqlSelect.getFrom().getKind().equals(JOIN)){ + SqlJoin joinNode = (SqlJoin) sqlSelect.getFrom(); + SqlNode joinCondition = joinNode.getCondition(); + replaceFieldName(((SqlBasicCall)joinCondition).operands[0], oldTbName, newTbName, mappingField); + replaceFieldName(((SqlBasicCall)joinCondition).operands[1], oldTbName, newTbName, mappingField); + } + + //TODO 暂时不处理having + SqlNode 
sqlHaving = sqlSelect.getHaving(); + + List newSelectNodeList = Lists.newArrayList(); + for( int i=0; i replaceNodeList = replaceSelectStarFieldName(selectNode, replaceInfo); + //newSelectNodeList.addAll(replaceNodeList); + throw new RuntimeException("not support table.* now"); + } + + SqlNode replaceNode = replaceSelectFieldName(selectNode, oldTbName, newTbName, mappingField); + if(replaceNode == null){ + continue; + } + + newSelectNodeList.add(replaceNode); + } + + SqlNodeList newSelectList = new SqlNodeList(newSelectNodeList, sqlSelectList.getParserPosition()); + sqlSelect.setSelectList(newSelectList); + + //where + if(whereNode != null){ + SqlNode[] sqlNodeList = ((SqlBasicCall)whereNode).getOperands(); + for(int i =0; i mappingField) { + if(orderNode.getKind() == IDENTIFIER){ + return createNewIdentify((SqlIdentifier) orderNode, oldTbName, newTbName, mappingField); + } else if (orderNode instanceof SqlBasicCall) { + SqlBasicCall sqlBasicCall = (SqlBasicCall) orderNode; + for(int i=0; i mappingField){ + if(groupNode.getKind() == IDENTIFIER){ + return createNewIdentify((SqlIdentifier) groupNode, oldTbName, newTbName, mappingField); + }else if(groupNode instanceof SqlBasicCall){ + SqlBasicCall sqlBasicCall = (SqlBasicCall) groupNode; + for(int i=0; i mappingField){ + + if (sqlIdentifier.names.size() == 1) { + return sqlIdentifier; + } + + String tableName = sqlIdentifier.names.get(0); + String fieldName = sqlIdentifier.names.get(1); + if(!tableName.equalsIgnoreCase(oldTbName)){ + return sqlIdentifier; + } + + String mappingFieldName = mappingField.get(fieldName); + if(mappingFieldName == null){ + return sqlIdentifier; + } + + sqlIdentifier = sqlIdentifier.setName(0, newTbName); + sqlIdentifier = sqlIdentifier.setName(1, mappingFieldName); + return sqlIdentifier; + } + + public static boolean filterNodeWithTargetName(SqlNode sqlNode, String targetTableName) { + + SqlKind sqlKind = sqlNode.getKind(); + switch (sqlKind){ + case SELECT: + SqlNode fromNode = 
((SqlSelect)sqlNode).getFrom(); + if(fromNode.getKind() == AS && ((SqlBasicCall)fromNode).getOperands()[0].getKind() == IDENTIFIER){ + if(((SqlBasicCall)fromNode).getOperands()[0].toString().equalsIgnoreCase(targetTableName)){ + return true; + }else{ + return false; + } + }else{ + return filterNodeWithTargetName(fromNode, targetTableName); + } + case AS: + SqlNode aliasName = ((SqlBasicCall)sqlNode).getOperands()[1]; + return aliasName.toString().equalsIgnoreCase(targetTableName); + case JOIN: + SqlNode leftNode = ((SqlJoin)sqlNode).getLeft(); + SqlNode rightNode = ((SqlJoin)sqlNode).getRight(); + boolean leftReturn = filterNodeWithTargetName(leftNode, targetTableName); + boolean rightReturn = filterNodeWithTargetName(rightNode, targetTableName); + + return leftReturn || rightReturn; + + default: + return false; + } + } + + public static SqlNode replaceSelectFieldName(SqlNode selectNode, + String oldTbName, + String newTbName, + Map mappingField) { + + if (selectNode.getKind() == AS) { + SqlNode leftNode = ((SqlBasicCall) selectNode).getOperands()[0]; + SqlNode replaceNode = replaceSelectFieldName(leftNode, oldTbName, newTbName, mappingField); + if (replaceNode != null) { + ((SqlBasicCall) selectNode).getOperands()[0] = replaceNode; + } + + return selectNode; + }else if(selectNode.getKind() == IDENTIFIER){ + return createNewIdentify((SqlIdentifier) selectNode, oldTbName, newTbName, mappingField); + }else if(selectNode.getKind() == LITERAL || selectNode.getKind() == LITERAL_CHAIN){//字面含义 + return selectNode; + }else if( AGGREGATE.contains(selectNode.getKind()) + || AVG_AGG_FUNCTIONS.contains(selectNode.getKind()) + || COMPARISON.contains(selectNode.getKind()) + || selectNode.getKind() == OTHER_FUNCTION + || selectNode.getKind() == DIVIDE + || selectNode.getKind() == CAST + || selectNode.getKind() == TRIM + || selectNode.getKind() == TIMES + || selectNode.getKind() == PLUS + || selectNode.getKind() == NOT_IN + || selectNode.getKind() == OR + || selectNode.getKind() 
== AND + || selectNode.getKind() == MINUS + || selectNode.getKind() == TUMBLE + || selectNode.getKind() == TUMBLE_START + || selectNode.getKind() == TUMBLE_END + || selectNode.getKind() == SESSION + || selectNode.getKind() == SESSION_START + || selectNode.getKind() == SESSION_END + || selectNode.getKind() == HOP + || selectNode.getKind() == HOP_START + || selectNode.getKind() == HOP_END + || selectNode.getKind() == BETWEEN + || selectNode.getKind() == IS_NULL + || selectNode.getKind() == IS_NOT_NULL + || selectNode.getKind() == CONTAINS + || selectNode.getKind() == TIMESTAMP_ADD + || selectNode.getKind() == TIMESTAMP_DIFF + || selectNode.getKind() == LIKE + + ){ + SqlBasicCall sqlBasicCall = (SqlBasicCall) selectNode; + for(int i=0; i replaceSelectStarFieldName(SqlNode selectNode, FieldReplaceInfo replaceInfo){ + SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; + List sqlNodes = Lists.newArrayList(); + if(sqlIdentifier.isStar()){//处理 [* or table.*] + int identifierSize = sqlIdentifier.names.size(); + Collection columns = null; + if(identifierSize == 1){ + columns = replaceInfo.getMappingTable().values(); + }else{ + columns = replaceInfo.getMappingTable().row(sqlIdentifier.names.get(0)).values(); + } + + for(String colAlias : columns){ + SqlParserPos sqlParserPos = new SqlParserPos(0, 0); + List columnInfo = Lists.newArrayList(); + columnInfo.add(replaceInfo.getTargetTableAlias()); + columnInfo.add(colAlias); + SqlIdentifier sqlIdentifierAlias = new SqlIdentifier(columnInfo, sqlParserPos); + sqlNodes.add(sqlIdentifierAlias); + } + + return sqlNodes; + }else{ + throw new RuntimeException("is not a star select field." 
+ selectNode); + } + } + +} diff --git a/core/src/main/java/com/dtstack/flink/sql/util/ParseUtils.java b/core/src/main/java/com/dtstack/flink/sql/util/ParseUtils.java index d399b533c..2c5ccd9ca 100644 --- a/core/src/main/java/com/dtstack/flink/sql/util/ParseUtils.java +++ b/core/src/main/java/com/dtstack/flink/sql/util/ParseUtils.java @@ -37,14 +37,12 @@ package com.dtstack.flink.sql.util; import com.google.common.collect.HashBasedTable; +import com.google.common.collect.HashBiMap; import org.apache.calcite.sql.*; import org.apache.commons.lang3.StringUtils; import org.apache.flink.api.java.tuple.Tuple2; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import static org.apache.calcite.sql.SqlKind.*; @@ -145,11 +143,42 @@ public static void fillFieldNameMapping(HashBasedTable m public static String dealDuplicateFieldName(HashBasedTable mappingTable, String fieldName) { String mappingFieldName = fieldName; - int index = 0; + int index = 1; while (!mappingTable.column(mappingFieldName).isEmpty()) { - mappingFieldName = mappingFieldName + index; + mappingFieldName = suffixWithChar(fieldName, '0', index); index++; } return mappingFieldName; } + + public static String dealDuplicateFieldName(HashBiMap refFieldMap, String fieldName) { + String mappingFieldName = fieldName; + int index = 1; + while (refFieldMap.inverse().get(mappingFieldName) != null ) { + mappingFieldName = suffixWithChar(fieldName, '0', index); + index++; + } + + return mappingFieldName; + } + + public static String dealDuplicateFieldName(Map refFieldMap, String fieldName) { + String mappingFieldName = fieldName; + int index = 1; + while (refFieldMap.containsKey(mappingFieldName)){ + mappingFieldName = suffixWithChar(fieldName, '0', index); + index++; + } + + return mappingFieldName; + } + + public static String suffixWithChar(String str, char padChar, int repeat){ + StringBuilder stringBuilder = new StringBuilder(str); + for(int i=0; i 
fromTableSet = Sets.newHashSet(); + TableUtils.getFromTableInfo(joinInfo.getLeftNode(), fromTableSet); + lefTbAlias = StringUtils.join(fromTableSet, "_"); + } + + String newTableAlias = !StringUtils.isEmpty(tableAlias) ? tableAlias : buildInternalTableName(lefTbAlias, SPLIT, joinInfo.getRightTableAlias()); if (null == sqlNode0) { sqlNode0 = new SqlIdentifier(newTableName, null, sqlParserPos); @@ -246,6 +257,21 @@ public static String dealSelectResultWithJoinInfo(JoinInfo joinInfo, SqlSelect s public static void replaceFromNodeForJoin(JoinInfo joinInfo, SqlSelect sqlNode) { //Update from node SqlBasicCall sqlBasicCall = buildAsNodeByJoinInfo(joinInfo, null, null); + String newAliasName = sqlBasicCall.operand(1).toString(); + + //替换select 中的属性为新的表名称和字段 + HashBasedTable fieldMapping = joinInfo.getTableFieldRef(); + Map leftFieldMapping = fieldMapping.row(joinInfo.getLeftTableAlias()); + Map rightFieldMapping = fieldMapping.row(joinInfo.getRightTableAlias()); + + /* for(SqlNode oneSelectNode : sqlNode.getSelectList()){ + replaceSelectFieldTable(oneSelectNode, joinInfo.getLeftTableAlias(), newAliasName, null ,leftFieldMapping); + replaceSelectFieldTable(oneSelectNode, joinInfo.getRightTableAlias(), newAliasName, null , rightFieldMapping); + }*/ + + //where中的条件属性为新的表名称和字段 + FieldReplaceUtil.replaceFieldName(sqlNode, joinInfo.getLeftTableAlias(), newAliasName, leftFieldMapping); + FieldReplaceUtil.replaceFieldName(sqlNode, joinInfo.getRightTableAlias(), newAliasName, rightFieldMapping); sqlNode.setFrom(sqlBasicCall); } @@ -277,10 +303,22 @@ public static void getFromTableInfo(SqlNode fromTable, Set tableNameSet) } } - public static void replaceSelectFieldTable(SqlNode selectNode, String oldTbName, String newTbName) { + /** + * 替换select 中的字段信息 + * 如果mappingTable 非空则从该参数获取字段的映射 + * 如果mappingTable 为空则根据是否存在新生成字段 + * @param selectNode + * @param oldTbName + * @param newTbName + * @param fieldReplaceRef + */ + public static void replaceSelectFieldTable(SqlNode selectNode, + 
String oldTbName, + String newTbName, + HashBiMap fieldReplaceRef) { if (selectNode.getKind() == AS) { SqlNode leftNode = ((SqlBasicCall) selectNode).getOperands()[0]; - replaceSelectFieldTable(leftNode, oldTbName, newTbName); + replaceSelectFieldTable(leftNode, oldTbName, newTbName, fieldReplaceRef); }else if(selectNode.getKind() == IDENTIFIER){ SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; @@ -289,9 +327,9 @@ public static void replaceSelectFieldTable(SqlNode selectNode, String oldTbName, return ; } - if(oldTbName.equalsIgnoreCase(((SqlIdentifier)selectNode).names.get(0))){ - SqlIdentifier newField = ((SqlIdentifier)selectNode).setName(0, newTbName); - ((SqlIdentifier)selectNode).assignNamesFrom(newField); + String fieldTableName = sqlIdentifier.names.get(0); + if(oldTbName.equalsIgnoreCase(fieldTableName)){ + replaceOneSelectField(sqlIdentifier, newTbName, oldTbName, fieldReplaceRef); } }else if(selectNode.getKind() == LITERAL || selectNode.getKind() == LITERAL_CHAIN){//字面含义 @@ -338,7 +376,7 @@ public static void replaceSelectFieldTable(SqlNode selectNode, String oldTbName, continue; } - replaceSelectFieldTable(sqlNode, oldTbName, newTbName); + replaceSelectFieldTable(sqlNode, oldTbName, newTbName, fieldReplaceRef); } }else if(selectNode.getKind() == CASE){ @@ -349,16 +387,16 @@ public static void replaceSelectFieldTable(SqlNode selectNode, String oldTbName, for(int i=0; i fieldReplaceRef){ + SqlIdentifier newField = sqlIdentifier.setName(0, newTbName); + String fieldName = sqlIdentifier.names.get(1); + String fieldKey = oldTbName + "_" + fieldName; + + if(!fieldReplaceRef.containsKey(fieldKey)){ + if(fieldReplaceRef.inverse().get(fieldName) != null){ + //换一个名字 + String mappingFieldName = ParseUtils.dealDuplicateFieldName(fieldReplaceRef, fieldName); + newField = newField.setName(1, mappingFieldName); + fieldReplaceRef.put(fieldKey, mappingFieldName); + } else { + fieldReplaceRef.put(fieldKey, fieldName); + } + }else { + newField = 
newField.setName(1, fieldReplaceRef.get(fieldKey)); + } + + sqlIdentifier.assignNamesFrom(newField); + } + /** * 替换另外join 表的指定表名为新关联处理的表名称 * @param condition - * @param tableRef + * @param oldTabFieldRefNew */ - public static void replaceJoinFieldRefTableName(SqlNode condition, Map tableRef){ + public static void replaceJoinFieldRefTableName(SqlNode condition, Map oldTabFieldRefNew){ + if (null == condition || condition.getKind() == LITERAL) { + return; + } SqlKind joinKind = condition.getKind(); if( joinKind == AND || joinKind == EQUALS ){ - replaceJoinFieldRefTableName(((SqlBasicCall)condition).operands[0], tableRef); - replaceJoinFieldRefTableName(((SqlBasicCall)condition).operands[1], tableRef); + replaceJoinFieldRefTableName(((SqlBasicCall)condition).operands[0], oldTabFieldRefNew); + replaceJoinFieldRefTableName(((SqlBasicCall)condition).operands[1], oldTabFieldRefNew); }else{ Preconditions.checkState(((SqlIdentifier)condition).names.size() == 2, "join condition must be format table.field"); String fieldRefTable = ((SqlIdentifier)condition).names.get(0); - String targetTableName = TableUtils.getTargetRefTable(tableRef, fieldRefTable); - if(StringUtils.isNotBlank(targetTableName) && !fieldRefTable.equalsIgnoreCase(targetTableName)){ - SqlIdentifier newField = ((SqlIdentifier)condition).setName(0, targetTableName); + String targetFieldName = TableUtils.getTargetRefField(oldTabFieldRefNew, condition.toString()); + + if(StringUtils.isNotBlank(targetFieldName)){ + String[] fieldSplits = StringUtils.split(targetFieldName, "."); + SqlIdentifier newField = ((SqlIdentifier)condition).setName(0, fieldSplits[0]); + newField = newField.setName(1, fieldSplits[1]); ((SqlIdentifier)condition).assignNamesFrom(newField); } } @@ -401,7 +469,19 @@ public static String getTargetRefTable(Map refTableMap, String t return preTableName; } - public static void replaceWhereCondition(SqlNode parentWhere, String oldTbName, String newTbName){ + public static String getTargetRefField(Map 
refFieldMap, String currFieldName){ + String targetFieldName = null; + String preFieldName; + + do { + preFieldName = targetFieldName == null ? currFieldName : targetFieldName; + targetFieldName = refFieldMap.get(preFieldName); + } while (targetFieldName != null); + + return preFieldName; + } + + public static void replaceWhereCondition(SqlNode parentWhere, String oldTbName, String newTbName, HashBiMap fieldReplaceRef){ if(parentWhere == null){ return; @@ -409,15 +489,15 @@ public static void replaceWhereCondition(SqlNode parentWhere, String oldTbName, SqlKind kind = parentWhere.getKind(); if(kind == AND){ - replaceWhereCondition(((SqlBasicCall) parentWhere).getOperands()[0], oldTbName, newTbName); - replaceWhereCondition(((SqlBasicCall) parentWhere).getOperands()[1], oldTbName, newTbName); + replaceWhereCondition(((SqlBasicCall) parentWhere).getOperands()[0], oldTbName, newTbName, fieldReplaceRef); + replaceWhereCondition(((SqlBasicCall) parentWhere).getOperands()[1], oldTbName, newTbName, fieldReplaceRef); } else { - replaceConditionNode(parentWhere, oldTbName, newTbName); + replaceConditionNode(parentWhere, oldTbName, newTbName, fieldReplaceRef); } } - private static void replaceConditionNode(SqlNode selectNode, String oldTbName, String newTbName) { + private static void replaceConditionNode(SqlNode selectNode, String oldTbName, String newTbName, HashBiMap fieldReplaceRef) { if(selectNode.getKind() == IDENTIFIER){ SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; @@ -426,8 +506,14 @@ private static void replaceConditionNode(SqlNode selectNode, String oldTbName, S } String tableName = sqlIdentifier.names.asList().get(0); + String tableField = sqlIdentifier.names.asList().get(1); + String fieldKey = tableName + "_" + tableField; + if(tableName.equalsIgnoreCase(oldTbName)){ + + String newFieldName = fieldReplaceRef.get(fieldKey) == null ? 
tableField : fieldReplaceRef.get(fieldKey); SqlIdentifier newField = ((SqlIdentifier)selectNode).setName(0, newTbName); + newField = newField.setName(1, newFieldName); ((SqlIdentifier)selectNode).assignNamesFrom(newField); } return; @@ -475,7 +561,7 @@ private static void replaceConditionNode(SqlNode selectNode, String oldTbName, S continue; } - replaceConditionNode(sqlNode, oldTbName, newTbName); + replaceConditionNode(sqlNode, oldTbName, newTbName, fieldReplaceRef); } return; @@ -490,18 +576,13 @@ private static void replaceConditionNode(SqlNode selectNode, String oldTbName, S /** * 获取条件中关联的表信息 * @param selectNode - * @param tableNameSet + * @param fieldInfos */ - public static void getConditionRefTable(SqlNode selectNode, Set tableNameSet) { + public static void getConditionRefTable(SqlNode selectNode, Set fieldInfos) { if(selectNode.getKind() == IDENTIFIER){ SqlIdentifier sqlIdentifier = (SqlIdentifier) selectNode; - if(sqlIdentifier.names.size() == 1){ - return; - } - - String tableName = sqlIdentifier.names.asList().get(0); - tableNameSet.add(tableName); + fieldInfos.add(sqlIdentifier.toString()); return; }else if(selectNode.getKind() == LITERAL || selectNode.getKind() == LITERAL_CHAIN){//字面含义 return; @@ -547,7 +628,7 @@ public static void getConditionRefTable(SqlNode selectNode, Set tableNam continue; } - getConditionRefTable(sqlNode, tableNameSet); + getConditionRefTable(sqlNode, fieldInfos); } return; @@ -558,4 +639,10 @@ public static void getConditionRefTable(SqlNode selectNode, Set tableNam throw new RuntimeException(String.format("not support node kind of %s to replace name now.", selectNode.getKind())); } } + + public static String buildTableField(String tableName, String fieldName){ + return String.format("%s.%s", tableName, fieldName); + } + + } diff --git a/docs/function.md b/docs/function.md new file mode 100644 index 000000000..e272011d2 --- /dev/null +++ b/docs/function.md @@ -0,0 +1,109 @@ +## 支持UDF,UDTF,UDAF: + +### UDTF使用案例 + +1. 
cross join:左表的每一行数据都会关联上UDTF 产出的每一行数据,如果UDTF不产出任何数据,那么这1行不会输出。 +2. left join:左表的每一行数据都会关联上UDTF 产出的每一行数据,如果UDTF不产出任何数据,则这1行的UDTF的字段会用null值填充。 left join UDTF 语句后面必须接 on true参数。 + + +场景:将某个字段拆分为两个字段。 + +```sql + +create table function UDTFOneColumnToMultiColumn with cn.todd.flink180.udflib.UDTFOneColumnToMultiColumn; + +CREATE TABLE MyTable ( + userID VARCHAR , + eventType VARCHAR, + productID VARCHAR) +WITH ( + type = 'kafka11', + bootstrapServers = '172.16.8.107:9092', + zookeeperQuorum = '172.16.8.107:2181/kafka', + offsetReset = 'latest', + topic ='mqTest03', + topicIsPattern = 'false' +); + +CREATE TABLE MyTable1 ( + channel VARCHAR , + pv VARCHAR, + name VARCHAR) +WITH ( + type = 'kafka11', + bootstrapServers = '172.16.8.107:9092', + zookeeperQuorum = '172.16.8.107:2181/kafka', + offsetReset = 'latest', + topic ='mqTest01', + topicIsPattern = 'false' +); + +CREATE TABLE MyTable2 ( + userID VARCHAR, + eventType VARCHAR, + productID VARCHAR, + date1 VARCHAR, + time1 VARCHAR +) +WITH ( + type = 'console', + bootstrapServers = '172.16.8.107:9092', + zookeeperQuorum = '172.16.8.107:2181/kafka', + offsetReset = 'latest', + topic ='mqTest02', + topicIsPattern = 'false' +); + +-- 视图使用UDTF +--create view udtf_table as +-- select MyTable.userID,MyTable.eventType,MyTable.productID,date1,time1 + -- from MyTable LEFT JOIN lateral table(UDTFOneColumnToMultiColumn(productID)) + -- as T(date1,time1) on true; + + + + +insert + into + MyTable2 +select + userID,eventType,productID,date1,time1 +from + ( + select MyTable.userID,MyTable.eventType,MyTable.productID,date1,time1 + from MyTable ,lateral table(UDTFOneColumnToMultiColumn(productID)) as T(date1,time1)) as udtf_table; + +``` +一行转多列UDTFOneColumnToMultiColumn + +```java +public class UDTFOneColumnToMultiColumn extends TableFunction<Row> { + public void eval(String value) { + String[] valueSplits = value.split("_"); + + //一行,两列 + Row row = new Row(2); + row.setField(0, valueSplits[0]); + row.setField(1, valueSplits[1]); + 
collect(row); + } + + @Override + public TypeInformation getResultType() { + return new RowTypeInfo(Types.STRING, Types.STRING); + } +} +``` + +输入输出: + + +输入: {"userID": "user_5", "eventType": "browse", "productID":"product_5"} + +输出: + + +--------+-----------+-----------+---------+-------+ + | userID | eventType | productID | date1 | time1 | + +--------+-----------+-----------+---------+-------+ + | user_5 | browse | product_5 | product | 5 | + +--------+-----------+-----------+---------+-------+ \ No newline at end of file diff --git a/rdb/rdb-side/src/main/java/com/dtstack/flink/sql/side/rdb/async/RdbAsyncReqRow.java b/rdb/rdb-side/src/main/java/com/dtstack/flink/sql/side/rdb/async/RdbAsyncReqRow.java index 68d7937dc..2d767ccd9 100644 --- a/rdb/rdb-side/src/main/java/com/dtstack/flink/sql/side/rdb/async/RdbAsyncReqRow.java +++ b/rdb/rdb-side/src/main/java/com/dtstack/flink/sql/side/rdb/async/RdbAsyncReqRow.java @@ -201,7 +201,6 @@ protected List getRows(CRow inputRow, List cacheContent, List entry : sideInfo.getInFieldIndex().entrySet()) { Object obj = input.getField(entry.getValue()); boolean isTimeIndicatorTypeInfo = TimeIndicatorTypeInfo.class.isAssignableFrom(sideInfo.getRowTypeInfo().getTypeAt(entry.getValue()).getClass()); @@ -216,7 +215,8 @@ public Row fillData(Row input, Object line) { if (jsonArray == null) { row.setField(entry.getKey(), null); } else { - Object object = SwitchUtil.getTarget(jsonArray.getValue(entry.getValue()), fields[entry.getValue()]); + String fieldType = sideInfo.getSelectSideFieldType(entry.getValue()); + Object object = SwitchUtil.getTarget(jsonArray.getValue(entry.getValue()), fieldType); row.setField(entry.getKey(), object); } }