Skip to content

Commit

Permalink
Merge pull request #49 from lighttiger2505/fix-multi-value-insert
Browse files Browse the repository at this point in the history
Fix multi value insert
  • Loading branch information
lighttiger2505 authored Mar 9, 2021
2 parents 4cea80a + 7767dad commit 472d06b
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 51 deletions.
6 changes: 3 additions & 3 deletions ast/astutil/astutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (nr *NodeReader) CurNodeIs(nm NodeMatcher) bool {
return false
}

func isEnclose(node ast.Node, pos token.Pos) bool {
func IsEnclose(node ast.Node, pos token.Pos) bool {
if 0 <= token.ComparePos(pos, node.Pos()) && 0 >= token.ComparePos(pos, node.End()) {
return true
}
Expand All @@ -176,15 +176,15 @@ func isEnclose(node ast.Node, pos token.Pos) bool {

func (nr *NodeReader) CurNodeEncloseIs(pos token.Pos) bool {
if nr.CurNode != nil {
return isEnclose(nr.CurNode, pos)
return IsEnclose(nr.CurNode, pos)
}
return false
}

func (nr *NodeReader) PeekNodeEncloseIs(pos token.Pos) bool {
_, peekNode := nr.PeekNode(false)
if peekNode != nil {
return isEnclose(peekNode, pos)
return IsEnclose(peekNode, pos)
}
return false
}
Expand Down
69 changes: 60 additions & 9 deletions internal/handler/signature_help_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handler

import (
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -17,19 +18,36 @@ type signatureHelpTestCase struct {
want lsp.SignatureHelp
}

// input is "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')"
var signatureHelpTestCases = []signatureHelpTestCase{
genInsertPositionTest(50, 0),
genInsertPositionTest(52, 0),
genInsertPositionTest(53, 1),
genInsertPositionTest(59, 1),
genInsertPositionTest(60, 2),
genInsertPositionTest(67, 2),
// single record
// input is "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')"
genSingleRecordInsertTest(50, 0),
genSingleRecordInsertTest(52, 0),
genSingleRecordInsertTest(53, 1),
genSingleRecordInsertTest(59, 1),
genSingleRecordInsertTest(60, 2),
genSingleRecordInsertTest(67, 2),

// multi record
// input is "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')"
genMultiRecordInsertTest(50, 0),
genMultiRecordInsertTest(52, 0),
genMultiRecordInsertTest(53, 1),
genMultiRecordInsertTest(59, 1),
genMultiRecordInsertTest(60, 2),
genMultiRecordInsertTest(67, 2),

genMultiRecordInsertTest(72, 0),
genMultiRecordInsertTest(74, 0),
genMultiRecordInsertTest(76, 1),
genMultiRecordInsertTest(81, 1),
genMultiRecordInsertTest(83, 2),
genMultiRecordInsertTest(89, 2),
}

func genInsertPositionTest(col int, wantActiveParameter int) signatureHelpTestCase {
func genSingleRecordInsertTest(col int, wantActiveParameter int) signatureHelpTestCase {
return signatureHelpTestCase{
name: "",
name: fmt.Sprintf("single record %d-%d", col, wantActiveParameter),
input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')",
line: 0,
col: col,
Expand Down Expand Up @@ -60,6 +78,39 @@ func genInsertPositionTest(col int, wantActiveParameter int) signatureHelpTestCa
}
}

func genMultiRecordInsertTest(col int, wantActiveParameter int) signatureHelpTestCase {
return signatureHelpTestCase{
name: fmt.Sprintf("multi record %d-%d", col, wantActiveParameter),
input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')",
line: 0,
col: col,
want: lsp.SignatureHelp{
Signatures: []lsp.SignatureInformation{
{
Label: "city (ID, Name, CountryCode)",
Documentation: "city table columns",
Parameters: []lsp.ParameterInformation{
{
Label: "ID",
Documentation: "int(11) PRI auto_increment",
},
{
Label: "Name",
Documentation: "char(35)",
},
{
Label: "CountryCode",
Documentation: "char(3) MUL",
},
},
},
},
ActiveSignature: 0.0,
ActiveParameter: float64(wantActiveParameter),
},
}
}

func TestSignatureHelpMain(t *testing.T) {
tx := newTestContext()
tx.initServer(t)
Expand Down
35 changes: 11 additions & 24 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ type (
func parsePrefixGroup(reader *astutil.NodeReader, matcher astutil.NodeMatcher, fn prefixParseFn) ast.TokenList {
var replaceNodes []ast.Node
for reader.NextNode(false) {
if list, ok := reader.CurNode.(ast.TokenList); ok {
newReader := astutil.NewNodeReader(list)
replaceNode := parsePrefixGroup(newReader, matcher, fn)
reader.Replace(replaceNode, reader.Index-1)
}
if reader.CurNodeIs(matcher) {
replaceNodes = append(replaceNodes, fn(reader))
} else if list, ok := reader.CurNode.(ast.TokenList); ok {
newReader := astutil.NewNodeReader(list)
replaceNodes = append(replaceNodes, parsePrefixGroup(newReader, matcher, fn))
} else {
replaceNodes = append(replaceNodes, reader.CurNode)
}
Expand All @@ -36,11 +38,13 @@ func parsePrefixGroup(reader *astutil.NodeReader, matcher astutil.NodeMatcher, f
func parseInfixGroup(reader *astutil.NodeReader, matcher astutil.NodeMatcher, ignoreWhiteSpace bool, fn infixParseFn) ast.TokenList {
var replaceNodes []ast.Node
for reader.NextNode(false) {
if list, ok := reader.CurNode.(ast.TokenList); ok {
newReader := astutil.NewNodeReader(list)
replaceNode := parseInfixGroup(newReader, matcher, ignoreWhiteSpace, fn)
reader.Replace(replaceNode, reader.Index-1)
}
if reader.PeekNodeIs(ignoreWhiteSpace, matcher) {
replaceNodes = append(replaceNodes, fn(reader))
} else if list, ok := reader.CurNode.(ast.TokenList); ok {
newReader := astutil.NewNodeReader(list)
replaceNodes = append(replaceNodes, parseInfixGroup(newReader, matcher, ignoreWhiteSpace, fn))
} else {
replaceNodes = append(replaceNodes, reader.CurNode)
}
Expand Down Expand Up @@ -95,11 +99,11 @@ func (p *Parser) Parse() (ast.TokenList, error) {
root = parsePrefixGroup(astutil.NewNodeReader(root), parenthesisPrefixMatcher, parseParenthesis)
root = parsePrefixGroup(astutil.NewNodeReader(root), functionPrefixMatcher, parseFunctions)
root = parsePrefixGroup(astutil.NewNodeReader(root), identifierPrefixMatcher, parseIdentifier)
root = parseInfixGroup(astutil.NewNodeReader(root), memberIdentifierInfixMatcher, false, parseMemberIdentifier)
root = parsePrefixGroup(astutil.NewNodeReader(root), switchCaseOpenMatcher, parseCase)

root = parsePrefixGroup(astutil.NewNodeReader(root), expressionPrefixMatcher, parseExpressionInParenthesis)

root = parseInfixGroup(astutil.NewNodeReader(root), memberIdentifierInfixMatcher, false, parseMemberIdentifier)
root = parsePrefixGroup(astutil.NewNodeReader(root), genMultiKeywordPrefixMatcher(), parseMultiKeyword)
root = parseInfixGroup(astutil.NewNodeReader(root), operatorInfixMatcher, true, parseOperator)
root = parseInfixGroup(astutil.NewNodeReader(root), comparisonInfixMatcher, true, parseComparison)
Expand Down Expand Up @@ -480,15 +484,6 @@ var aliasRecursionMatcher = astutil.NodeMatcher{
}

func parseAliasedWithoutAs(reader *astutil.NodeReader) ast.Node {
if reader.CurNodeIs(aliasRecursionMatcher) {
if list, ok := reader.CurNode.(ast.TokenList); ok {
// FIXME: more simplity
// For sub query
parenthesis := parsePrefixGroup(astutil.NewNodeReader(list), aliasLeftMatcher, parseAliasedWithoutAs)
reader.Replace(parenthesis, reader.Index-1)
}
}

if !reader.PeekNodeIs(true, aliasRightMatcher) {
return reader.CurNode
}
Expand All @@ -510,14 +505,6 @@ func parseAliased(reader *astutil.NodeReader) ast.Node {
if !reader.CurNodeIs(aliasLeftMatcher) {
return reader.CurNode
}
if reader.CurNodeIs(aliasRecursionMatcher) {
if list, ok := reader.CurNode.(ast.TokenList); ok {
// FIXME: more simplity
// For sub query
parenthesis := parseInfixGroup(astutil.NewNodeReader(list), aliasInfixMatcher, true, parseAliased)
reader.Replace(parenthesis, reader.Index-1)
}
}

realName := reader.CurNode
_, as := reader.PeekNode(true)
Expand Down
46 changes: 43 additions & 3 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,27 @@ func TestParseAliased(t *testing.T) {
testItem(t, parenthesis[8], ")")
},
},
{
name: "aliase in parenthesis",
input: "(SELECT ci.ID AS city_id, ci.Name AS city_name FROM world.city AS ci)",
checkFn: func(t *testing.T, stmts []*ast.Statement, input string) {
testStatement(t, stmts[0], 1, input)

list := stmts[0].GetTokens()
testParenthesis(t, list[0], input)

parenthesis := testTokenList(t, list[0], 9).GetTokens()
testItem(t, parenthesis[0], "(")
testItem(t, parenthesis[1], "SELECT")
testItem(t, parenthesis[2], " ")
testIdentifierList(t, parenthesis[3], "ci.ID AS city_id, ci.Name AS city_name")
testItem(t, parenthesis[4], " ")
testItem(t, parenthesis[5], "FROM")
testItem(t, parenthesis[6], " ")
testAliased(t, parenthesis[7], "world.city AS ci", "world.city", "ci")
testItem(t, parenthesis[8], ")")
},
},
}
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -1039,6 +1060,25 @@ func TestParseIdentifierList(t *testing.T) {
testIdentifierList_GetIndex(t, il, genPosOneline(18), -1)
},
},
{
name: "parenthesis list",
input: "(foo, bar, foobar), (fooo, barr, fooobarr)",
checkFn: func(t *testing.T, stmts []*ast.Statement, input string) {
testStatement(t, stmts[0], 4, input)
list := stmts[0].GetTokens()

parenthesis1 := testParenthesis(t, list[0], "(foo, bar, foobar)")
tokens1 := parenthesis1.Inner().GetTokens()
testIdentifierList(t, tokens1[0], "foo, bar, foobar")

testItem(t, list[1], ",")
testItem(t, list[2], " ")

parenthesis2 := testParenthesis(t, list[3], "(fooo, barr, fooobarr)")
tokens2 := parenthesis2.Inner().GetTokens()
testIdentifierList(t, tokens2[0], "fooo, barr, fooobarr")
},
},
{
name: "invalid parenthesis",
input: "(foo, bar,",
Expand Down Expand Up @@ -1374,13 +1414,13 @@ func testAliased(t *testing.T, node ast.Node, expect string, realName, aliasedNa

func testIdentifierList(t *testing.T, node ast.Node, expect string) *ast.IdentiferList {
t.Helper()
if expect != node.String() {
t.Errorf("expected %q, got %q", expect, node.String())
}
il, ok := node.(*ast.IdentiferList)
if !ok {
t.Fatalf("invalid type want IdentiferList got %T", node)
}
if expect != node.String() {
t.Errorf("expected %q, got %q", expect, node.String())
}
return il
}

Expand Down
28 changes: 22 additions & 6 deletions parser/parseutil/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package parseutil
import (
"github.com/lighttiger2505/sqls/ast"
"github.com/lighttiger2505/sqls/ast/astutil"
"github.com/lighttiger2505/sqls/token"
)

type (
Expand Down Expand Up @@ -159,6 +160,7 @@ func parseInsertColumns(reader *astutil.NodeReader) []ast.Node {
ast.TypeParenthesis,
},
}

if !reader.PeekNodeIs(true, insertColumnsParenthesis) {
return []ast.Node{}
}
Expand All @@ -168,20 +170,34 @@ func parseInsertColumns(reader *astutil.NodeReader) []ast.Node {
if !ok {
return []ast.Node{}
}
identList, ok := parenthesis.Inner().GetTokens()[0].(*ast.IdentiferList)
if !ok {
return []ast.Node{}

inner, ok := parenthesis.Inner().(*ast.IdentiferList)
if ok {
return []ast.Node{inner}
}
return []ast.Node{identList}
firstToken, ok := parenthesis.Inner().GetTokens()[0].(*ast.IdentiferList)
if ok {
return []ast.Node{firstToken}
}
return []ast.Node{}
}

func ExtractInsertValues(parsed ast.TokenList) []ast.Node {
func ExtractInsertValues(parsed ast.TokenList, pos token.Pos) []ast.Node {
insertTableIdentifer := astutil.NodeMatcher{
ExpectTokens: []token.Kind{
token.Comma,
},
ExpectKeyword: []string{
"VALUES",
},
}
return parsePrefix(astutil.NewNodeReader(parsed), insertTableIdentifer, parseInsertColumns)
values := parsePrefix(astutil.NewNodeReader(parsed), insertTableIdentifer, parseInsertValues)
for _, v := range values {
if astutil.IsEnclose(v, pos) {
return []ast.Node{v}
}
}
return []ast.Node{}
}

func parseInsertValues(reader *astutil.NodeReader) []ast.Node {
Expand Down
28 changes: 27 additions & 1 deletion parser/parseutil/extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package parseutil

import (
"testing"

"github.com/lighttiger2505/sqls/token"
)

func TestExtractSelectExpr(t *testing.T) {
Expand Down Expand Up @@ -279,25 +281,49 @@ func TestExtractInsertValues(t *testing.T) {
testcases := []struct {
name string
input string
pos token.Pos
want []string
}{
{
name: "full",
input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')",
pos: token.Pos{
Line: 0,
Col: 50,
},
want: []string{
"123, 'aaa', '2020'",
},
},
{
name: "multi value",
input: "insert into city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020'), (456, 'bbb', '2021')",
pos: token.Pos{
Line: 0,
Col: 72,
},
want: []string{
"456, 'bbb', '2021'",
},
},
{
name: "with out statement",
input: "city (ID, Name, CountryCode) VALUES (123, 'aaa', '2020')",
pos: token.Pos{
Line: 0,
Col: 38,
},
want: []string{
"123, 'aaa', '2020'",
},
},
{
name: "minimum",
input: "VALUES (123, 'aaa', '2020')",
pos: token.Pos{
Line: 0,
Col: 9,
},
want: []string{
"123, 'aaa', '2020'",
},
Expand All @@ -306,7 +332,7 @@ func TestExtractInsertValues(t *testing.T) {
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
query := initExtractTable(t, tt.input)
gots := ExtractInsertValues(query)
gots := ExtractInsertValues(query, tt.pos)

if len(gots) != len(tt.want) {
t.Errorf("contain nodes %d, got %d (%v)", len(tt.want), len(gots), gots)
Expand Down
Loading

0 comments on commit 472d06b

Please sign in to comment.