Skip to content

Commit

Permalink
feat(dbm-services): 分析语法文件里面关联的具体表,细化模拟执行的范围 #8925
Browse files Browse the repository at this point in the history
  • Loading branch information
ymakedaq authored and zhangzhw8 committed Jan 10, 2025
1 parent 035f68f commit 727a5ad
Show file tree
Hide file tree
Showing 17 changed files with 357 additions and 60 deletions.
114 changes: 95 additions & 19 deletions dbm-services/mysql/db-simulation/app/syntax/parse_relation_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"runtime/debug"
Expand All @@ -29,7 +30,8 @@ import (
const AnalyzeConcurrency = 10

// DoParseRelationDbs parse relation db from sql file
func (tf *TmysqlParseFile) DoParseRelationDbs(version string) (createDbs, relationDbs []string, dumpAll bool,
func (tf *TmysqlParseFile) DoParseRelationDbs(version string) (createDbs, relationDbs, allCommands []string,
dumpAll bool,
err error) {
logger.Info("doing....")
tf.result = make(map[string]*CheckInfo)
Expand All @@ -39,15 +41,13 @@ func (tf *TmysqlParseFile) DoParseRelationDbs(version string) (createDbs, relati
if !tf.IsLocalFile {
if err = tf.Init(); err != nil {
logger.Error("Do init failed %s", err.Error())
return nil, nil, false, err
return nil, nil, nil, false, err
}
if err = tf.Downloadfile(); err != nil {
logger.Error("failed to download sql file from the product library %s", err.Error())
return nil, nil, false, err
return nil, nil, nil, false, err
}
}
// 最后删除临时目录,不会返回错误
defer tf.delTempDir()
logger.Info("all sqlfiles download ok ~")
alreadExecutedSqlfileChan := make(chan string, len(tf.Param.FileNames))

Expand All @@ -59,25 +59,26 @@ func (tf *TmysqlParseFile) DoParseRelationDbs(version string) (createDbs, relati
}()

logger.Info("start to analyze the parsing result")
createDbs, relationDbs, dumpAll, err = tf.doParseInchan(alreadExecutedSqlfileChan, version)
createDbs, relationDbs, allCommands, dumpAll, err = tf.doParseInchan(alreadExecutedSqlfileChan, version)
if err != nil {
logger.Error("failed to analyze the parsing result:%s", err.Error())
return nil, nil, false, err
return nil, nil, nil, false, err
}
logger.Info("createDbs:%v,relationDbs:%v,dumpAll:%v,err:%v", createDbs, relationDbs, dumpAll, err)
logger.Info("createDbs:%v,relationDbs:%v,allcomands%v,dumpAll:%v,err:%v", createDbs, relationDbs, allCommands, dumpAll,
err)
dumpdbs := []string{}
for _, d := range relationDbs {
if slices.Contains(createDbs, d) {
continue
}
dumpdbs = append(dumpdbs, d)
}
return lo.Uniq(createDbs), lo.Uniq(dumpdbs), dumpAll, nil
return lo.Uniq(createDbs), lo.Uniq(dumpdbs), lo.Uniq(allCommands), dumpAll, nil
}

// doParseInchan RelationDbs do parse relation db
func (t *TmysqlParse) doParseInchan(alreadExecutedSqlfileCh chan string,
mysqlVersion string) (createDbs []string, relationDbs []string, dumpAll bool, err error) {
mysqlVersion string) (createDbs []string, relationDbs []string, allCommands []string, dumpAll bool, err error) {
var errs []error
c := make(chan struct{}, AnalyzeConcurrency)
errChan := make(chan error)
Expand All @@ -90,21 +91,23 @@ func (t *TmysqlParse) doParseInchan(alreadExecutedSqlfileCh chan string,
c <- struct{}{}
go func(fileName string) {
defer wg.Done()
cdbs, dbs, dumpAllDbs, err := t.analyzeRelationDbs(fileName, mysqlVersion)
cdbs, dbs, commands, dumpAllDbs, err := t.analyzeRelationDbs(fileName, mysqlVersion)
logger.Info("createDbs:%v,dbs:%v,dumpAllDbs:%v,err:%v", cdbs, dbs, dumpAllDbs, err)
if err != nil {
logger.Error("analyzeRelationDbs failed %s", err.Error())
errChan <- err
return
}
// 如果有dumpall 则直接返回退出,不在继续分析
if dumpAllDbs {
dumpAll = true
<-c
wg.Done()
stopChan <- struct{}{}
}
t.mu.Lock()
relationDbs = append(relationDbs, dbs...)
createDbs = append(createDbs, cdbs...)
allCommands = append(allCommands, commands...)
t.mu.Unlock()
<-c
}(sqlfile)
Expand All @@ -121,7 +124,7 @@ func (t *TmysqlParse) doParseInchan(alreadExecutedSqlfileCh chan string,
case err := <-errChan:
errs = append(errs, err)
case <-stopChan:
return createDbs, relationDbs, dumpAll, errors.Join(errs...)
return createDbs, relationDbs, allCommands, dumpAll, errors.Join(errs...)
}
}
}
Expand All @@ -130,6 +133,7 @@ func (t *TmysqlParse) doParseInchan(alreadExecutedSqlfileCh chan string,
func (t *TmysqlParse) analyzeRelationDbs(inputfileName, mysqlVersion string) (
createDbs []string,
relationDbs []string,
allCommandType []string,
dumpAll bool,
err error) {
defer func() {
Expand All @@ -140,7 +144,7 @@ func (t *TmysqlParse) analyzeRelationDbs(inputfileName, mysqlVersion string) (
f, err := os.Open(t.getAbsoutputfilePath(inputfileName, mysqlVersion))
if err != nil {
logger.Error("open file failed %s", err.Error())
return nil, nil, false, err
return nil, nil, nil, false, err
}
defer f.Close()
reader := bufio.NewReader(f)
Expand All @@ -151,24 +155,27 @@ func (t *TmysqlParse) analyzeRelationDbs(inputfileName, mysqlVersion string) (
break
}
logger.Error("read Line Error %s", errx.Error())
return nil, nil, false, errx
return nil, nil, nil, false, errx
}
if len(line) == 1 && line[0] == byte('\n') {
continue
}
var res ParseLineQueryBase
if err = json.Unmarshal(line, &res); err != nil {
logger.Error("json unmasrshal line:%s failed %s", string(line), err.Error())
return nil, nil, false, err
return nil, nil, nil, false, err
}
// 判断是否有语法错误
if res.ErrorCode != 0 {
return nil, nil, false, err
return nil, nil, nil, false, fmt.Errorf("%s", res.ErrorMsg)
}
if lo.IsNotEmpty(res.Command) {
allCommandType = append(allCommandType, res.Command)
}
if slices.Contains([]string{SQLTypeCreateProcedure, SQLTypeCreateFunction, SQLTypeCreateView, SQLTypeCreateTrigger,
SQLTypeInsertSelect, SQLTypeRelaceSelect},
res.Command) {
return nil, nil, true, nil
return nil, nil, nil, true, nil
}
if lo.IsEmpty(res.DbName) {
continue
Expand All @@ -181,5 +188,74 @@ func (t *TmysqlParse) analyzeRelationDbs(inputfileName, mysqlVersion string) (
relationDbs = append(relationDbs, res.DbName)

}
return createDbs, relationDbs, false, nil
return createDbs, relationDbs, allCommandType, false, nil
}

// ParseSpecialTbls parse special tables
func (tf *TmysqlParseFile) ParseSpecialTbls(mysqlVersion string) (relationTbls []RelationTbl, err error) {
m := make(map[string][]string)
for _, fileName := range tf.Param.FileNames {
mm, err := tf.parseSpecialSQLFile(fileName, mysqlVersion)
if err != nil {
logger.Error("parseAlterSQLFile failed %s", err.Error())
return nil, err
}
for k, v := range mm {
m[k] = append(m[k], v...)
}
}
for k, v := range m {
relationTbls = append(relationTbls, RelationTbl{
DbName: k,
Tbls: v,
})
}
return relationTbls, nil
}

// RelationTbl dunmp db and table
type RelationTbl struct {
DbName string `json:"db_name"`
Tbls []string `json:"tbls"`
}

// parseSpecialSQLFile 解析指定库表
func (t *TmysqlParse) parseSpecialSQLFile(inputfileName, mysqlVersion string) (m map[string][]string, err error) {
f, err := os.Open(t.getAbsoutputfilePath(inputfileName, mysqlVersion))
if err != nil {
logger.Error("open file failed %s", err.Error())
return nil, err
}
m = make(map[string][]string)
defer f.Close()
reader := bufio.NewReader(f)
for {
line, errx := reader.ReadBytes(byte('\n'))
if errx != nil {
if errx == io.EOF {
break
}
logger.Error("read Line Error %s", errx.Error())
return nil, errx
}
if len(line) == 1 && line[0] == byte('\n') {
continue
}
var baseRes ParseIncludeTableBase
if err = json.Unmarshal(line, &baseRes); err != nil {
logger.Error("json unmasrshal line:%s failed %s", string(line), err.Error())
return nil, err
}
dbName := ""
if baseRes.Command == SQLTypeUseDb {
dbName = baseRes.DbName
}
if lo.IsNotEmpty(baseRes.DbName) {
dbName = baseRes.DbName
}
if lo.IsNotEmpty(baseRes.TableName) {
m[dbName] = append(m[dbName], baseRes.TableName)
}
}
return m, nil
}
7 changes: 4 additions & 3 deletions dbm-services/mysql/db-simulation/app/syntax/syntax.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri
}
}
// 最后删除临时目录,不会返回错误
defer tf.delTempDir()
defer tf.DelTempDir()

var errs []error
for _, version := range versions {
Expand Down Expand Up @@ -184,7 +184,7 @@ func (tf *TmysqlParseFile) CreateAndUploadDDLTblFile() (err error) {
}
// 最后删除临时目录,不会返回错误
// 暂时屏蔽 观察过程文件
defer tf.delTempDir()
defer tf.DelTempDir()

if err = tf.Downloadfile(); err != nil {
logger.Error("failed to download sql file from the product library %s", err.Error())
Expand Down Expand Up @@ -248,7 +248,8 @@ func (t *TmysqlParse) Init() (err error) {
return nil
}

func (t *TmysqlParse) delTempDir() {
// DelTempDir TODO
func (t *TmysqlParse) DelTempDir() {
if err := os.RemoveAll(t.tmpWorkdir); err != nil {
logger.Warn("remove tempDir:" + t.tmpWorkdir + ".error info:" + err.Error())
}
Expand Down
12 changes: 12 additions & 0 deletions dbm-services/mysql/db-simulation/app/syntax/tmysqlpase_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ const (
SQLTypeInsertSelect = "insert_select"
// SQLTypeRelaceSelect replace select sql
SQLTypeRelaceSelect = "replace_select"
// SQLTypeDropTable drop table sql
SQLTypeDropTable = "drop_table"
// SQLTypeCreateIndex is creat table sql
SQLTypeCreateIndex = "create_index"
)

// NotAllowedDefaulValColMap 不允许默认值的字段
Expand Down Expand Up @@ -309,3 +313,11 @@ type UpdateResult struct {
HasWhere bool `json:"has_where"`
Limit int `json:"limit"`
}

// ParseIncludeTableBase parse include table
type ParseIncludeTableBase struct {
QueryID int `json:"query_id"`
Command string `json:"command"`
DbName string `json:"db_name"`
TableName string `json:"table_name"`
}
52 changes: 50 additions & 2 deletions dbm-services/mysql/db-simulation/handler/syntax_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/viper"

"dbm-services/common/go-pubpkg/cmutil"
Expand Down Expand Up @@ -251,11 +252,29 @@ func (s SyntaxHandler) ParseSQLFileRelationDb(r *gin.Context) {
FileNames: param.Files,
},
}
createDbs, dbs, dumpall, err := p.DoParseRelationDbs("")
createDbs, dbs, allCommands, dumpall, err := p.DoParseRelationDbs("")
if err != nil {
s.SendResponse(r, err, nil)
return
}
// 如果所有的命令都是alter table, dump指定库表
logger.Debug("debug: %v,%d", allCommands, len(allCommands))
if isAllOperateTable(allCommands) || isAllCreateTable(allCommands) {
relationTbls, err := p.ParseSpecialTbls("")
if err != nil {
s.SendResponse(r, err, nil)
return
}
s.SendResponse(r, nil, gin.H{
"create_dbs": createDbs,
"dbs": dbs,
"dump_all": false,
"just_dump_special_tbls": true,
"special_tbls": relationTbls,
"timestamp": time.Now().Unix(),
})
return
}

s.SendResponse(r, nil, gin.H{
"create_dbs": createDbs,
Expand All @@ -265,6 +284,15 @@ func (s SyntaxHandler) ParseSQLFileRelationDb(r *gin.Context) {
})
}

func isAllOperateTable(allCommands []string) bool {
return lo.Every([]string{syntax.SQLTypeAlterTable, syntax.SQLTypeUseDb,
syntax.SQLTypeCreateIndex, syntax.SQLTypeDropTable}, allCommands)
}

func isAllCreateTable(allCommands []string) bool {
return lo.Every([]string{syntax.SQLTypeCreateTable, syntax.SQLTypeUseDb}, allCommands)
}

// ParseSQLRelationDb 语法检查入参SQL string
func (s *SyntaxHandler) ParseSQLRelationDb(r *gin.Context) {
var param CheckSQLStringParam
Expand Down Expand Up @@ -298,11 +326,31 @@ func (s *SyntaxHandler) ParseSQLRelationDb(r *gin.Context) {
FileNames: []string{fileName},
},
}
createDbs, dbs, dumpall, err := p.DoParseRelationDbs("")
// defer p.DelTempDir()
createDbs, dbs, allCommands, dumpall, err := p.DoParseRelationDbs("")
if err != nil {
s.SendResponse(r, err, nil)
return
}
// 如果所有的命令都是alter table, dump指定库表
logger.Info("make debug: %v,%d", allCommands, len(allCommands))
if isAllOperateTable(allCommands) || isAllCreateTable(allCommands) {
relationTbls, err := p.ParseSpecialTbls("")
if err != nil {
s.SendResponse(r, err, nil)
return
}
s.SendResponse(r, nil, gin.H{
"create_dbs": createDbs,
"dbs": dbs,
"dump_all": false,
"just_dump_special_tbls": true,
"special_tbls": relationTbls,
"timestamp": time.Now().Unix(),
})
return
}

s.SendResponse(r, nil, gin.H{
"create_dbs": createDbs,
"dbs": dbs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ func (m MysqlUpgradeComp) mysqlUpgrade(conn *native.DbWorker, port int) (err err
return nil
}
// open general_log
if errx := m.openGeneralLog(conn); err != nil {
if errx := m.openGeneralLog(conn); errx != nil {
logger.Warn("set global general_log=on failed %s", errx.Error())
}
upgradeScript := ""
Expand Down
Loading

0 comments on commit 727a5ad

Please sign in to comment.