Skip to content

Commit

Permalink
feat: optimize dynamic sql (#685)
Browse files Browse the repository at this point in the history
* feat:optimization dynamic sql

* feat:update unit test

Co-authored-by: 卢章强 <[email protected]>
  • Loading branch information
idersec and 卢章强 authored Oct 21, 2022
1 parent 4fb174e commit c237629
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 95 deletions.
28 changes: 14 additions & 14 deletions internal/generate/clause_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ func TestClause(t *testing.T) {
GenerateResult: []string{
"generateSQL.WriteString(\"select * from users \")",
"var whereSQL0 strings.Builder",
"params[\"id\"] = id",
"whereSQL0.WriteString(\"id>@id \")",
"params = append(params,id)",
"whereSQL0.WriteString(\"id>? \")",
"helper.JoinWhereBuilder(&generateSQL,whereSQL0)",
},
},
Expand All @@ -87,8 +87,8 @@ func TestClause(t *testing.T) {
"generateSQL.WriteString(\"select * from users \")",
"var whereSQL0 strings.Builder",
"if id > 0 {",
"params[\"id\"] = id",
"whereSQL0.WriteString(\"id>@id \")",
"params = append(params,id)",
"whereSQL0.WriteString(\"id>? \")",
"}",
"helper.JoinWhereBuilder(&generateSQL,whereSQL0)",
},
Expand Down Expand Up @@ -116,17 +116,17 @@ func TestClause(t *testing.T) {
"generateSQL.WriteString(\"update users \")",
"var setSQL0 strings.Builder",
"if name != \"\" {",
"params[\"name\"] = name",
"setSQL0.WriteString(\"name=@name \")",
"params = append(params,name)",
"setSQL0.WriteString(\"name=? \")",
"}",
"setSQL0.WriteString(\", \")",
"if id>0 {",
"params[\"id\"] = id",
"setSQL0.WriteString(\"id=@id \")",
"params = append(params,id)",
"setSQL0.WriteString(\"id=? \")",
"}",
"helper.JoinSetBuilder(&generateSQL,setSQL0)",
"params[\"id\"] = id",
"generateSQL.WriteString(\"where id=@id \")",
"params = append(params,id)",
"generateSQL.WriteString(\"where id=? \")",
},
},
{
Expand All @@ -135,7 +135,7 @@ func TestClause(t *testing.T) {
"\"select * from \"",
"\"users\"",
"where",
"for _index, name := range names",
"for _, name := range names",
"\"name=\"",
"name",
"end",
Expand All @@ -144,9 +144,9 @@ func TestClause(t *testing.T) {
GenerateResult: []string{
"generateSQL.WriteString(\"select * from users \")",
"var whereSQL0 strings.Builder",
"for _index, name := range names{",
"params[\"nameForWhereSQL0_\"+strconv.Itoa(_index)]=name",
"whereSQL0.WriteString(\"name=@nameForWhereSQL0_\"+strconv.Itoa(_index)+\" \")",
"for _, name := range names{",
"params = append(params,name)",
"whereSQL0.WriteString(\"name=? \")",
"}",
"helper.JoinWhereBuilder(&generateSQL,whereSQL0)",
},
Expand Down
6 changes: 4 additions & 2 deletions internal/generate/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,11 @@ func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) {
switch {
case param.Package == "UNDEFINED":
param.Package = m.Package
case param.IsMap() || param.IsGenM() || param.IsError() || param.IsNull():
case param.IsError() || param.IsNull():
return fmt.Errorf("type error on interface [%s] param: [%s]", m.InterfaceName, param.Name)
case param.IsGenM():
param.Type = "map[string]interface{}"
param.Package = ""
case param.IsGenT():
param.Type = m.OriginStruct.Type
param.Package = m.OriginStruct.Package
Expand Down Expand Up @@ -185,7 +188,6 @@ func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {
param.SetName("result")
param.Type = m.OriginStruct.Type
param.Package = m.OriginStruct.Package
param.IsPointer = true
m.ResultData = param
case param.IsInterface():
return fmt.Errorf("query method can not return interface in [%s.%s]", m.InterfaceName, m.MethodName)
Expand Down
82 changes: 18 additions & 64 deletions internal/generate/section.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,6 @@ func (s *Section) appendTmpl(value string) {
s.Tmpls = append(s.Tmpls, value)
}

func (s *Section) isInForValue(value string) (ForRange, bool) {
valueList := strings.Split(value, ".")
for _, v := range s.forValue {
if v.value == valueList[0] {
if len(valueList) > 1 {
v.suffix = "." + strings.Join(valueList[1:], ".")
}
return v, true
}
}
return ForRange{}, false
}

func (s *Section) hasSameName(value string) bool {
for _, p := range s.members {
if p.Type == model.FOR && p.ForRange.value == value {
Expand Down Expand Up @@ -438,15 +425,8 @@ func (s *Section) parseSQL(name string) (res SQLClause) {
case model.VARIABLE:
res.Value = append(res.Value, c.Value)
case model.DATA:
forRange, isInForRange := s.isInForValue(c.Value)
if isInForRange {
s.appendTmpl(forRange.appendDataToParams(c.Value, name))
c.Value = forRange.DataValue(c.Value, name)
} else {
s.appendTmpl(c.AddDataToParamMap())
c.Value = strconv.Quote("@" + c.SQLParamName())
}
res.Value = append(res.Value, c.Value)
s.appendTmpl(fmt.Sprintf("params = append(params,%s)", c.Value))
res.Value = append(res.Value, "\"?\"")
default:
s.SubIndex()
return
Expand All @@ -460,28 +440,24 @@ func (s *Section) parseSQL(name string) (res SQLClause) {

// checkSQLVar check sql variable by for loops value and external params
func (s *Section) checkSQLVar(param string, status model.Status, method *InterfaceMethod) (result section, err error) {
paramName := strings.Split(param, ".")[0]
for index, part := range s.members {
if part.Type == model.FOR && part.ForRange.value == paramName {
switch status {
case model.DATA:
method.HasForParams = true
if part.ForRange.index == "_" {
s.members[index].SetForRangeKey("_index")
}
case model.VARIABLE:
param = fmt.Sprintf("%s.Quote(%s)", method.S, param)
}
result = section{
Type: status,
Value: param,
}
return
if status == model.VARIABLE && param == "table" {
result = section{
Type: model.SQL,
Value: strconv.Quote(method.Table),
}

return
}

return method.checkSQLVarByParams(param, status)
if status == model.DATA {
method.HasForParams = true
}
if status == model.VARIABLE {
param = fmt.Sprintf("%s.Quote(%s)", method.S, param)
}
result = section{
Type: status,
Value: param,
}
return
}

// GetName ...
Expand Down Expand Up @@ -581,15 +557,6 @@ func (s *section) sectionType(str string) error {
return nil
}

func (s *section) SetForRangeKey(key string) {
s.ForRange.index = key
s.Value = s.String()
}

func (s *section) AddDataToParamMap() string {
return fmt.Sprintf("params[%q] = %s", s.SQLParamName(), s.Value)
}

func (s *section) SQLParamName() string {
return strings.Replace(s.Value, ".", "", -1)
}
Expand All @@ -605,16 +572,3 @@ type ForRange struct {
func (f *ForRange) String() string {
return fmt.Sprintf("for %s, %s := range %s", f.index, f.value, f.rangeList)
}

func (f *ForRange) mapIndexName(prefix, dataName, clauseName string) string {
return fmt.Sprintf("\"%s%sFor%s_\"+strconv.Itoa(%s)", prefix, strings.Replace(dataName, ".", "", -1), strings.Title(clauseName), f.index)
}

// DataValue return data value
func (f *ForRange) DataValue(dataName, clauseName string) string {
return f.mapIndexName("@", dataName, clauseName)
}

func (f *ForRange) appendDataToParams(dataName, clauseName string) string {
return fmt.Sprintf("params[%s]=%s%s", f.mapIndexName("", dataName, clauseName), f.value, f.suffix)
}
16 changes: 7 additions & 9 deletions internal/parser/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@ func GetInterfacePath(v interface{}) (paths []*InterfacePath, err error) {
path.Name = n
}

ctx := build.Default
var p *build.Package

if strings.Split(arg.String(), ".")[0] == "main" {
_, file, _, ok := runtime.Caller(3)
if ok {
path.Files = append(path.Files, file)
}
paths = append(paths, &path)
continue
_, file, _, _ := runtime.Caller(3)
p, err = ctx.ImportDir(filepath.Dir(file), build.ImportComment)
} else {
p, err = ctx.Import(arg.PkgPath(), "", build.ImportComment)
}

ctx := build.Default
var p *build.Package
p, err = ctx.Import(arg.PkgPath(), "", build.ImportComment)
if err != nil {
return
}
Expand Down
9 changes: 3 additions & 6 deletions internal/template/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@ const DIYMethod = `
//{{.DocComment }}
func ({{.S}} {{.TargetStruct}}Do){{.FuncSign}}{
{{if .HasSQLData}}params :=make(map[string]interface{},0)
{{if .HasSQLData}}var params []interface{}
{{end}}var generateSQL strings.Builder
{{range $line:=.Section.Tmpls}}{{$line}}
{{end}}
{{if .HasNeedNewResult}}result ={{if .ResultData.IsMap}}make{{else}}new{{end}}({{if ne .ResultData.Package ""}}{{.ResultData.Package}}.{{end}}{{.ResultData.Type}}){{end}}
{{if or .ReturnRowsAffected .ReturnError}}var executeSQL *gorm.DB
{{end}}{{if .HasSQLData}}if len(params)>0{
{{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String(){{if .HasSQLData}},params{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}}
}else{
{{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String()){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}}
}{{else}}{{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String()){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}}{{end}}
{{end}}
{{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}} = {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String(){{if .HasSQLData}},params...{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}}
{{if .ReturnRowsAffected}}rowsAffected = executeSQL.RowsAffected
{{end}}{{if .ReturnError}}err = executeSQL.Error
{{end}}return
Expand Down

0 comments on commit c237629

Please sign in to comment.