From 6bd8fdc47ddc1d2e1ee97785edbd6182ee54b385 Mon Sep 17 00:00:00 2001 From: k4n4ry Date: Mon, 3 Jun 2024 02:37:18 +0900 Subject: [PATCH 1/3] Handle assignments to global variables from the init function --- assertion/global/analyzer.go | 17 +++++++- assertion/global/globalvarinit.go | 40 +++++++++++++++++-- .../go.uber.org/globalvars/globalvarinit.go | 8 ++++ 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/assertion/global/analyzer.go b/assertion/global/analyzer.go index 972dc5e2..1240455a 100644 --- a/assertion/global/analyzer.go +++ b/assertion/global/analyzer.go @@ -49,6 +49,7 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { if !conf.IsFileInScope(file) { continue } + initFuncDecl := getInitFuncDecl(file) for _, decl := range file.Decls { genDecl, ok := decl.(*ast.GenDecl) @@ -56,10 +57,24 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { continue } for _, spec := range genDecl.Specs { - fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec))...) + fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec), initFuncDecl)...) } } } return fullTriggers, nil } + +// getInitFuncDecl searches for the init function declaration in the given *ast.File. +// It returns the *ast.FuncDecl representing the init function if found, or nil otherwise. +func getInitFuncDecl(file *ast.File) *ast.FuncDecl { + if file == nil { + return nil + } + for _, decl := range file.Decls { + if funcDecl, ok := decl.(*ast.FuncDecl); ok && funcDecl.Name.Name == "init" { + return funcDecl + } + } + return nil +} diff --git a/assertion/global/globalvarinit.go b/assertion/global/globalvarinit.go index 3c38ce3b..d253c7ff 100644 --- a/assertion/global/globalvarinit.go +++ b/assertion/global/globalvarinit.go @@ -24,10 +24,10 @@ import ( ) // analyzeValueSpec returns full triggers corresponding to the declaration -func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec) []annotation.FullTrigger { +func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) []annotation.FullTrigger { var fullTriggers []annotation.FullTrigger - consumers := getGlobalConsumers(pass, spec) + consumers := getGlobalConsumers(pass, spec, initFuncDecl) for i, ident := range spec.Names { if consumers[i] == nil { @@ -63,12 +63,12 @@ func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec) []annotation.Ful } // Returns a list of consumers corresponding to a global level variable declaration -func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec) []*annotation.ConsumeTrigger { +func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) []*annotation.ConsumeTrigger { consumers := make([]*annotation.ConsumeTrigger, len(valspec.Names)) for i, name := range valspec.Names { // Types that are not nilable are eliminated here - if !util.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) && !util.IsEmptyExpr(name) { + if !util.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) && !util.IsEmptyExpr(name) && !hasGlobalVarAssignInInitFunc(valspec, initFuncDecl) { v := pass.TypesInfo.ObjectOf(name).(*types.Var) consumers[i] = &annotation.ConsumeTrigger{ Annotation: &annotation.GlobalVarAssign{ @@ -85,6 +85,38 @@ func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec) []*annotati return consumers } +// Checks if all the global variables represented by spec are assigned values within the init function. +// It returns true if all variables are assigned, false otherwise. +// If initFuncDecl is nil, it returns false. +func hasGlobalVarAssignInInitFunc(spec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) bool { + if initFuncDecl == nil { + return false + } + assignedVars := make(map[string]bool) + for _, name := range spec.Names { + assignedVars[name.Name] = false + } + ast.Inspect(initFuncDecl.Body, func(node ast.Node) bool { + if assign, ok := node.(*ast.AssignStmt); ok { + for _, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok { + if _, exists := assignedVars[ident.Name]; exists { + assignedVars[ident.Name] = true + } + } + } + } + return true + }) + + for _, assigned := range assignedVars { + if !assigned { + return false + } + } + return true +} + // Returns a producer in the cases: 1) func call 2) literal nil 3) another global var 4) struct field/method. // In all other cases, it returns nil. func getGlobalProducer(pass *analysis.Pass, valspec *ast.ValueSpec, lid int, rid int) *annotation.ProduceTrigger { diff --git a/testdata/src/go.uber.org/globalvars/globalvarinit.go b/testdata/src/go.uber.org/globalvars/globalvarinit.go index ee2aab21..57ee3893 100644 --- a/testdata/src/go.uber.org/globalvars/globalvarinit.go +++ b/testdata/src/go.uber.org/globalvars/globalvarinit.go @@ -25,6 +25,14 @@ var x = 3 // This should throw an error since it is not initialized var noInit *int //want "assigned into global variable" +var _init *int +var _initMult1, _initMult2 *int + +func init() { + _init = new(int) + _initMult1 = new(int) + _initMult2 = new(int) +} // nilable(nilableVar) var nilableVar *int From 2450de2f9fa2e66fd0445264bc963db818e24944 Mon Sep 17 00:00:00 2001 From: k4n4ry Date: Thu, 13 Jun 2024 21:03:45 +0900 Subject: [PATCH 2/3] Refactor getInitFuncDecl to return a slice of *ast.FuncDecl --- assertion/global/analyzer.go | 11 ++++--- assertion/global/globalvarinit.go | 32 ++++++++++--------- .../go.uber.org/globalvars/globalvarinit.go | 6 ++++ 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/assertion/global/analyzer.go b/assertion/global/analyzer.go index 1240455a..5d85521e 100644 --- a/assertion/global/analyzer.go +++ b/assertion/global/analyzer.go @@ -49,7 +49,7 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { if !conf.IsFileInScope(file) { continue } - initFuncDecl := getInitFuncDecl(file) + initFuncDecls := getInitFuncDecls(file) for _, decl := range file.Decls { genDecl, ok := decl.(*ast.GenDecl) @@ -57,7 +57,7 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { continue } for _, spec := range genDecl.Specs { - fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec), initFuncDecl)...) + fullTriggers = append(fullTriggers, analyzeValueSpec(pass, spec.(*ast.ValueSpec), initFuncDecls)...) } } } @@ -67,14 +67,15 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { // getInitFuncDecl searches for the init function declaration in the given *ast.File. // It returns the *ast.FuncDecl representing the init function if found, or nil otherwise. -func getInitFuncDecl(file *ast.File) *ast.FuncDecl { +func getInitFuncDecls(file *ast.File) []*ast.FuncDecl { if file == nil { return nil } + var initFuncDecls []*ast.FuncDecl for _, decl := range file.Decls { if funcDecl, ok := decl.(*ast.FuncDecl); ok && funcDecl.Name.Name == "init" { - return funcDecl + initFuncDecls = append(initFuncDecls, funcDecl) } } - return nil + return initFuncDecls } diff --git a/assertion/global/globalvarinit.go b/assertion/global/globalvarinit.go index d253c7ff..c5135829 100644 --- a/assertion/global/globalvarinit.go +++ b/assertion/global/globalvarinit.go @@ -24,10 +24,10 @@ import ( ) // analyzeValueSpec returns full triggers corresponding to the declaration -func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) []annotation.FullTrigger { +func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) []annotation.FullTrigger { var fullTriggers []annotation.FullTrigger - consumers := getGlobalConsumers(pass, spec, initFuncDecl) + consumers := getGlobalConsumers(pass, spec, initFuncDecls) for i, ident := range spec.Names { if consumers[i] == nil { @@ -63,12 +63,12 @@ func analyzeValueSpec(pass *analysis.Pass, spec *ast.ValueSpec, initFuncDecl *as } // Returns a list of consumers corresponding to a global level variable declaration -func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) []*annotation.ConsumeTrigger { +func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) []*annotation.ConsumeTrigger { consumers := make([]*annotation.ConsumeTrigger, len(valspec.Names)) for i, name := range valspec.Names { // Types that are not nilable are eliminated here - if !util.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) && !util.IsEmptyExpr(name) && !hasGlobalVarAssignInInitFunc(valspec, initFuncDecl) { + if !util.TypeBarsNilness(pass.TypesInfo.TypeOf(name)) && !util.IsEmptyExpr(name) && !hasGlobalVarAssignInInitFunc(valspec, initFuncDecls) { v := pass.TypesInfo.ObjectOf(name).(*types.Var) consumers[i] = &annotation.ConsumeTrigger{ Annotation: &annotation.GlobalVarAssign{ @@ -88,26 +88,28 @@ func getGlobalConsumers(pass *analysis.Pass, valspec *ast.ValueSpec, initFuncDec // Checks if all the global variables represented by spec are assigned values within the init function. // It returns true if all variables are assigned, false otherwise. // If initFuncDecl is nil, it returns false. -func hasGlobalVarAssignInInitFunc(spec *ast.ValueSpec, initFuncDecl *ast.FuncDecl) bool { - if initFuncDecl == nil { +func hasGlobalVarAssignInInitFunc(spec *ast.ValueSpec, initFuncDecls []*ast.FuncDecl) bool { + if len(initFuncDecls) == 0 { return false } assignedVars := make(map[string]bool) for _, name := range spec.Names { assignedVars[name.Name] = false } - ast.Inspect(initFuncDecl.Body, func(node ast.Node) bool { - if assign, ok := node.(*ast.AssignStmt); ok { - for _, lhs := range assign.Lhs { - if ident, ok := lhs.(*ast.Ident); ok { - if _, exists := assignedVars[ident.Name]; exists { - assignedVars[ident.Name] = true + for _, initFuncDecl := range initFuncDecls { + ast.Inspect(initFuncDecl.Body, func(node ast.Node) bool { + if assign, ok := node.(*ast.AssignStmt); ok { + for _, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok { + if _, exists := assignedVars[ident.Name]; exists { + assignedVars[ident.Name] = true + } } } } - } - return true - }) + return true + }) + } for _, assigned := range assignedVars { if !assigned { diff --git a/testdata/src/go.uber.org/globalvars/globalvarinit.go b/testdata/src/go.uber.org/globalvars/globalvarinit.go index 57ee3893..3c9aaaa8 100644 --- a/testdata/src/go.uber.org/globalvars/globalvarinit.go +++ b/testdata/src/go.uber.org/globalvars/globalvarinit.go @@ -34,6 +34,12 @@ func init() { _initMult2 = new(int) } +var _init2 *int + +func init() { + _init2 = new(int) +} + // nilable(nilableVar) var nilableVar *int var assignedNilable = nilableVar //want "assigned" From dce532a2509d6ebfb1452310abf809847b289370 Mon Sep 17 00:00:00 2001 From: k4n4ry Date: Fri, 21 Jun 2024 12:48:34 +0900 Subject: [PATCH 3/3] Refactor getInitFuncDecls to handle multiple init functions and related functions --- assertion/global/analyzer.go | 38 +++++++++++++++++-- .../go.uber.org/globalvars/globalvarinit.go | 15 ++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/assertion/global/analyzer.go b/assertion/global/analyzer.go index 5d85521e..4bc8762c 100644 --- a/assertion/global/analyzer.go +++ b/assertion/global/analyzer.go @@ -65,16 +65,48 @@ func run(pass *analysis.Pass) ([]annotation.FullTrigger, error) { return fullTriggers, nil } -// getInitFuncDecl searches for the init function declaration in the given *ast.File. -// It returns the *ast.FuncDecl representing the init function if found, or nil otherwise. +// getInitFuncDecls searches for the init function declarations and all related functions in the given *ast.File. +// It returns a slice of *ast.FuncDecl representing the init functions and all functions called directly or indirectly from them. +// The function handles multiple init functions if present, and avoids infinite recursion in case of cyclic function calls. +// If the file is nil, it returns nil. func getInitFuncDecls(file *ast.File) []*ast.FuncDecl { if file == nil { return nil } + funcDecls := make(map[string]*ast.FuncDecl) + for _, decl := range file.Decls { + if funcDecl, ok := decl.(*ast.FuncDecl); ok { + funcDecls[funcDecl.Name.Name] = funcDecl + } + } + var initFuncDecls []*ast.FuncDecl + // visitedFuncs tracks processed functions to prevent infinite recursion and duplicate processing + visitedFuncs := make(map[string]struct{}) + var findRelatedFuncs func(*ast.FuncDecl) + findRelatedFuncs = func(funcDecl *ast.FuncDecl) { + if _, visited := visitedFuncs[funcDecl.Name.Name]; visited { + return + } + initFuncDecls = append(initFuncDecls, funcDecl) + visitedFuncs[funcDecl.Name.Name] = struct{}{} + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { + if callExpr, ok := n.(*ast.CallExpr); ok { + if ident, ok := callExpr.Fun.(*ast.Ident); ok { + if funcDecl, exists := funcDecls[ident.Name]; exists { + findRelatedFuncs(funcDecl) + } + } + } + return true + }) + } + for _, decl := range file.Decls { if funcDecl, ok := decl.(*ast.FuncDecl); ok && funcDecl.Name.Name == "init" { - initFuncDecls = append(initFuncDecls, funcDecl) + findRelatedFuncs(funcDecl) + // Reset visitedFuncs for each init function to ensure all related functions are processed + visitedFuncs = make(map[string]struct{}) } } return initFuncDecls diff --git a/testdata/src/go.uber.org/globalvars/globalvarinit.go b/testdata/src/go.uber.org/globalvars/globalvarinit.go index 3c9aaaa8..b3b64e9a 100644 --- a/testdata/src/go.uber.org/globalvars/globalvarinit.go +++ b/testdata/src/go.uber.org/globalvars/globalvarinit.go @@ -40,6 +40,21 @@ func init() { _init2 = new(int) } +var _init3, _init4 *int + +func init() { + init_next() +} + +func init_next() { + _init3 = new(int) + init_next_next() +} + +func init_next_next() { + _init4 = new(int) +} + // nilable(nilableVar) var nilableVar *int var assignedNilable = nilableVar //want "assigned"