Skip to content

Commit

Permalink
Add support for nested function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Dec 12, 2024
1 parent de862a9 commit bb9f8d1
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.23.0"
const VERSION = "0.24.0"

func main() {
config := LoadConfig()
Expand Down
4 changes: 4 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"index"},
"values": {"1"},
},
"SELECT * FROM generate_series(1, array_upper(current_schemas(FALSE), 1)) AS series(index) LIMIT 1": {
"description": {"index"},
"values": {"1"},
},
// Transformed JOIN's
"SELECT s.usename, r.rolconfig FROM pg_catalog.pg_shadow s LEFT JOIN pg_catalog.pg_roles r ON s.usename = r.rolname": {
"description": {"usename", "rolconfig"},
Expand Down
113 changes: 61 additions & 52 deletions src/query_parser_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,14 @@ const (
PG_TABLE_TABLES = "tables"

PG_FUNCTION_PG_GET_KEYWORDS = "pg_get_keywords"
PG_FUNCTION_ARRAY_UPPER = "array_upper"
)

type QueryParserTable struct {
config *Config
utils *QueryUtils
}

type FunctionCall struct {
Schema string
Function string
Alias string
}

func NewQueryParserTable(config *Config) *QueryParserTable {
return &QueryParserTable{config: config, utils: NewQueryUtils(config)}
}
Expand All @@ -37,10 +32,9 @@ func (parser *QueryParserTable) NodeToSchemaTable(node *pgQuery.Node) SchemaTabl
rangeVar := node.GetRangeVar()
var alias string


if rangeVar.Alias != nil {
alias = rangeVar.Alias.Aliasname
}
if rangeVar.Alias != nil {
alias = rangeVar.Alias.Aliasname
}

return SchemaTable{
Schema: rangeVar.Schemaname,
Expand All @@ -49,45 +43,6 @@ func (parser *QueryParserTable) NodeToSchemaTable(node *pgQuery.Node) SchemaTabl
}
}

func (parser *QueryParserTable) NodeToFunctionCalls(node *pgQuery.Node) []FunctionCall {
var functionCalls []FunctionCall
rangeFunction := node.GetRangeFunction()

var alias string
if rangeFunction.Alias != nil {
alias = rangeFunction.Alias.Aliasname
}

for _, functionNode := range rangeFunction.Functions {
for _, item := range functionNode.GetList().Items {
funcCall := item.GetFuncCall()
if funcCall == nil {
continue
}

var functionCall FunctionCall

switch len(funcCall.Funcname) {
case 1:
functionCall = FunctionCall{
Function: funcCall.Funcname[0].GetString_().Sval,
Alias: alias,
}
case 2:
functionCall = FunctionCall{
Schema: funcCall.Funcname[0].GetString_().Sval,
Function: funcCall.Funcname[1].GetString_().Sval,
Alias: alias,
}
}

functionCalls = append(functionCalls, functionCall)
}
}

return functionCalls
}

// pg_catalog.pg_statio_user_tables
func (parser *QueryParserTable) IsPgStatioUserTablesTable(schemaTable SchemaTable) bool {
return parser.isPgCatalogSchema(schemaTable) && schemaTable.Table == PG_TABLE_PG_STATIO_USER_TABLES
Expand Down Expand Up @@ -242,12 +197,30 @@ func (parser *QueryParserTable) MakeIcebergTableNode(tablePath string) *pgQuery.
}

// pg_catalog.pg_get_keywords()
func (parser *QueryParserTable) IsPgGetKeywordsFunction(functionCall FunctionCall) bool {
return functionCall.Schema == PG_SCHEMA_PG_CATALOG && functionCall.Function == PG_FUNCTION_PG_GET_KEYWORDS
func (parser *QueryParserTable) IsPgGetKeywordsFunction(node *pgQuery.Node) bool {
for _, funcNode := range node.GetRangeFunction().Functions {
for _, funcItemNode := range funcNode.GetList().Items {
funcCallNode := funcItemNode.GetFuncCall()
if funcCallNode == nil {
continue
}
if len(funcCallNode.Funcname) != 2 {
continue
}

schema := funcCallNode.Funcname[0].GetString_().Sval
function := funcCallNode.Funcname[1].GetString_().Sval
if schema == PG_SCHEMA_PG_CATALOG && function == PG_FUNCTION_PG_GET_KEYWORDS {
return true
}
}
}

return false
}

// pg_catalog.pg_get_keywords() -> VALUES(values...) t(columns...)
func (parser *QueryParserTable) MakePgGetKeywordsNode(alias string) *pgQuery.Node {
func (parser *QueryParserTable) MakePgGetKeywordsNode(node *pgQuery.Node) *pgQuery.Node {
columns := []string{"word", "catcode", "barelabel", "catdesc", "baredesc"}

var rows [][]string
Expand Down Expand Up @@ -277,9 +250,45 @@ func (parser *QueryParserTable) MakePgGetKeywordsNode(alias string) *pgQuery.Nod
rows = append(rows, row)
}

var alias string
if node.GetAlias() != nil {
alias = node.GetAlias().Aliasname
}

return parser.utils.MakeSubselectNode(columns, rows, alias)
}

// array_upper(array, 1)
func (parser *QueryParserTable) IsArrayUpperFunction(funcCallNode *pgQuery.FuncCall) bool {
if len(funcCallNode.Funcname) != 1 {
return false
}

funcName := funcCallNode.Funcname[0].GetString_().Sval

if funcName == PG_FUNCTION_ARRAY_UPPER {
dimension := funcCallNode.Args[1].GetAConst().GetIval().Ival
if dimension == 1 {
return true
}
}

return false
}

// array_upper(array, 1) -> len(array)
func (parser *QueryParserTable) MakeArrayUpperNode(funcCallNode *pgQuery.FuncCall) *pgQuery.FuncCall {
return pgQuery.MakeFuncCallNode(
[]*pgQuery.Node{
pgQuery.MakeStrNode("len"),
},
[]*pgQuery.Node{
funcCallNode.Args[0],
},
0,
).GetFuncCall()
}

func (parser *QueryParserTable) isPgCatalogSchema(schemaTable SchemaTable) bool {
return schemaTable.Schema == PG_SCHEMA_PG_CATALOG || schemaTable.Schema == ""
}
Expand Down
97 changes: 77 additions & 20 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ func NewSelectRemapper(config *Config, icebergReader *IcebergReader) *SelectRema
}
}

// SELECT ...
func (selectRemapper *SelectRemapper) RemapQueryTreeWithSelect(queryTree *pgQuery.ParseResult) *pgQuery.ParseResult {
selectStatement := queryTree.Stmts[0].Stmt.GetSelectStmt()
selectStatement = selectRemapper.remapSelectStatement(selectStatement, 0)

return queryTree
}

// No-op
// SET ... (no-op)
func (selectRemapper *SelectRemapper) RemapQueryTreeWithSet(queryTree *pgQuery.ParseResult) *pgQuery.ParseResult {
setStatement := queryTree.Stmts[0].Stmt.GetVariableSetStmt()

Expand All @@ -70,48 +71,100 @@ func (selectRemapper *SelectRemapper) RemapQueryTreeWithSet(queryTree *pgQuery.P
func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
selectStatement = selectRemapper.remapTypeCastsInSelect(selectStatement)

// UNION
if selectStatement.FromClause == nil && selectStatement.Larg != nil && selectStatement.Rarg != nil {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" UNION left")
selectRemapper.logTreeTraversal("UNION left", indentLevel)
leftSelectStatement := selectStatement.Larg
leftSelectStatement = selectRemapper.remapSelectStatement(leftSelectStatement, indentLevel+1)
leftSelectStatement = selectRemapper.remapSelectStatement(leftSelectStatement, indentLevel+1) // self-recursion

LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" UNION right")
selectRemapper.logTreeTraversal("UNION right", indentLevel)
rightSelectStatement := selectStatement.Rarg
rightSelectStatement = selectRemapper.remapSelectStatement(rightSelectStatement, indentLevel+1)
rightSelectStatement = selectRemapper.remapSelectStatement(rightSelectStatement, indentLevel+1) // self-recursion

return selectStatement
}

// JOIN
if len(selectStatement.FromClause) > 0 && selectStatement.FromClause[0].GetJoinExpr() != nil {
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel)
selectRemapper.remapJoinExpressions(selectStatement.FromClause[0], indentLevel)
// SELECT
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive
selectStatement.FromClause[0] = selectRemapper.remapJoinExpressions(selectStatement.FromClause[0], indentLevel) // recursive with self-recursion
return selectStatement
}

// FROM
if len(selectStatement.FromClause) > 0 {
// WHERE
if selectStatement.FromClause[0].GetRangeVar() != nil {
selectRemapper.logTreeTraversal("WHERE statements", indentLevel)
selectStatement = selectRemapper.remapperWhere.RemapWhere(selectStatement)
}
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel)

// SELECT
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive

for i, fromNode := range selectStatement.FromClause {
if fromNode.GetRangeVar() != nil {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" SELECT statement")
selectRemapper.logTreeTraversal("FROM table", indentLevel)
selectStatement.FromClause[i] = selectRemapper.remapperTable.RemapTable(fromNode)
} else if fromNode.GetRangeSubselect() != nil {
selectRemapper.remapSelectStatement(fromNode.GetRangeSubselect().Subquery.GetSelectStmt(), indentLevel+1)
selectRemapper.logTreeTraversal("FROM subselect", indentLevel)
subSelectStatement := fromNode.GetRangeSubselect().Subquery.GetSelectStmt()
subSelectStatement = selectRemapper.remapSelectStatement(subSelectStatement, indentLevel+1) // self-recursion
}

if fromNode.GetRangeFunction() != nil {
selectStatement.FromClause[i] = selectRemapper.remapperTable.RemapTableFunction(fromNode)
selectStatement.FromClause[i] = selectRemapper.remapTableFunction(fromNode, indentLevel+1) // recursive
}
}
return selectStatement
}

selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel)
selectStatement = selectRemapper.remapSelect(selectStatement, indentLevel) // recursive
return selectStatement
}

// FROM PG_FUNCTION()
func (selectRemapper *SelectRemapper) remapTableFunction(fromNode *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.logTreeTraversal("FROM function()", indentLevel)

fromNode = selectRemapper.remapperTable.RemapTableFunction(fromNode)
if fromNode.GetRangeFunction() == nil {
return fromNode
}

for _, funcNode := range fromNode.GetRangeFunction().Functions {
for _, funcItemNode := range funcNode.GetList().Items {
funcCallNode := funcItemNode.GetFuncCall()
if funcCallNode == nil {
continue
}
funcCallNode = selectRemapper.remapTableFunctionArgs(funcCallNode, indentLevel+1) // recursive
}
}

return fromNode
}

// FROM PG_FUNCTION(PG_NESTED_FUNCTION())
func (selectRemapper *SelectRemapper) remapTableFunctionArgs(funcCallNode *pgQuery.FuncCall, indentLevel int) *pgQuery.FuncCall {
selectRemapper.logTreeTraversal("FROM nested_function()", indentLevel)

for i, argNode := range funcCallNode.GetArgs() {
nestedFunctionCall := argNode.GetFuncCall()
if nestedFunctionCall == nil {
continue
}

nestedFunctionCall = selectRemapper.remapperTable.RemapNestedTableFunction(nestedFunctionCall)
nestedFunctionCall = selectRemapper.remapTableFunctionArgs(nestedFunctionCall, indentLevel+1) // recursive

funcCallNode.Args[i].Node = &pgQuery.Node_FuncCall{FuncCall: nestedFunctionCall}
}

return funcCallNode
}

func (selectRemapper *SelectRemapper) remapTypeCastsInSelect(selectStatement *pgQuery.SelectStmt) *pgQuery.SelectStmt {
// WHERE [CONDITION]
if selectStatement.WhereClause != nil {
Expand Down Expand Up @@ -205,43 +258,43 @@ func (selectRemapper *SelectRemapper) remapTypeCastsInNode(node *pgQuery.Node) *
}

func (selectRemapper *SelectRemapper) remapJoinExpressions(node *pgQuery.Node, indentLevel int) *pgQuery.Node {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" JOIN left")
selectRemapper.logTreeTraversal("JOIN left", indentLevel+1)
leftJoinNode := node.GetJoinExpr().Larg
if leftJoinNode.GetJoinExpr() != nil {
leftJoinNode = selectRemapper.remapJoinExpressions(leftJoinNode, indentLevel+1)
leftJoinNode = selectRemapper.remapJoinExpressions(leftJoinNode, indentLevel+1) // self-recursion
} else if leftJoinNode.GetRangeVar() != nil {
leftJoinNode = selectRemapper.remapperTable.RemapTable(leftJoinNode)
} else if leftJoinNode.GetRangeSubselect() != nil {
leftSelectStatement := leftJoinNode.GetRangeSubselect().Subquery.GetSelectStmt()
leftSelectStatement = selectRemapper.remapSelectStatement(leftSelectStatement, indentLevel+1)
leftSelectStatement = selectRemapper.remapSelectStatement(leftSelectStatement, indentLevel+1) // parent-recursion
}
node.GetJoinExpr().Larg = leftJoinNode

LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" JOIN right")
selectRemapper.logTreeTraversal("JOIN right", indentLevel+1)
rightJoinNode := node.GetJoinExpr().Rarg
if rightJoinNode.GetJoinExpr() != nil {
rightJoinNode = selectRemapper.remapJoinExpressions(rightJoinNode, indentLevel+1)
rightJoinNode = selectRemapper.remapJoinExpressions(rightJoinNode, indentLevel+1) // self-recursion
} else if rightJoinNode.GetRangeVar() != nil {
rightJoinNode = selectRemapper.remapperTable.RemapTable(rightJoinNode)
} else if rightJoinNode.GetRangeSubselect() != nil {
rightSelectStatement := rightJoinNode.GetRangeSubselect().Subquery.GetSelectStmt()
rightSelectStatement = selectRemapper.remapSelectStatement(rightSelectStatement, indentLevel+1)
rightSelectStatement = selectRemapper.remapSelectStatement(rightSelectStatement, indentLevel+1) // parent-recursion
}
node.GetJoinExpr().Rarg = rightJoinNode

return node
}

func (selectRemapper *SelectRemapper) remapSelect(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel+1)+" SELECT functions")
selectRemapper.logTreeTraversal("SELECT statements", indentLevel+1)

for i, targetNode := range selectStatement.TargetList {
targetNode = selectRemapper.remapperSelect.RemapSelect(targetNode)

// Recursively remap sub-selects
subSelectStatement := selectRemapper.remapperSelect.SubselectStatement(targetNode)
if subSelectStatement != nil {
subSelectStatement = selectRemapper.remapSelect(subSelectStatement, indentLevel+1)
subSelectStatement = selectRemapper.remapSelect(subSelectStatement, indentLevel+1) // self-recursion
}

selectStatement.TargetList[i] = targetNode
Expand Down Expand Up @@ -285,3 +338,7 @@ func (selectRemapper *SelectRemapper) remapTypecast(node *pgQuery.Node) *pgQuery
}
return node
}

func (selectRemapper *SelectRemapper) logTreeTraversal(label string, indentLevel int) {
LogDebug(selectRemapper.config, strings.Repeat(">", indentLevel), label)
}
18 changes: 13 additions & 5 deletions src/select_remapper_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,24 @@ func (remapper *SelectRemapperTable) RemapTable(node *pgQuery.Node) *pgQuery.Nod

// FROM [PG_FUNCTION()]
func (remapper *SelectRemapperTable) RemapTableFunction(node *pgQuery.Node) *pgQuery.Node {
for _, functionCall := range remapper.parserTable.NodeToFunctionCalls(node) {
// pg_catalog.pg_get_keywords() -> hard-coded keywords
if remapper.parserTable.IsPgGetKeywordsFunction(functionCall) {
return remapper.parserTable.MakePgGetKeywordsNode(functionCall.Alias)
}
// pg_catalog.pg_get_keywords() -> hard-coded keywords
if remapper.parserTable.IsPgGetKeywordsFunction(node) {
return remapper.parserTable.MakePgGetKeywordsNode(node)
}

return node
}

// FROM PG_FUNCTION(PG_NESTED_FUNCTION())
func (remapper *SelectRemapperTable) RemapNestedTableFunction(funcCallNode *pgQuery.FuncCall) *pgQuery.FuncCall {
// array_upper(values, 1) -> len(values)
if remapper.parserTable.IsArrayUpperFunction(funcCallNode) {
return remapper.parserTable.MakeArrayUpperNode(funcCallNode)
}

return funcCallNode
}

func (remapper *SelectRemapperTable) overrideTable(node *pgQuery.Node, fromClause *pgQuery.Node) *pgQuery.Node {
node = fromClause
return node
Expand Down

0 comments on commit bb9f8d1

Please sign in to comment.