-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
283 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package nomutateglobal | ||
|
||
import ( | ||
"go/ast" | ||
"go/types" | ||
|
||
"golang.org/x/tools/go/analysis" | ||
) | ||
|
||
var Analyzer = &analysis.Analyzer{ | ||
Name: "noMutateGlobal", | ||
Doc: "prevents mutation of global variables", | ||
Run: run, | ||
} | ||
|
||
func run(pass *analysis.Pass) (any, error) { | ||
c := newGlobalCollector(pass) | ||
for _, f := range pass.Files { | ||
ast.Inspect(f, func(n ast.Node) bool { | ||
asmt, ok := n.(*ast.AssignStmt) | ||
if !ok { | ||
return true | ||
} | ||
for _, lhs := range asmt.Lhs { | ||
sel, ok := lhs.(*ast.SelectorExpr) | ||
if !ok { | ||
continue | ||
} | ||
if !c.contains(sel.X) { | ||
continue | ||
} | ||
pass.Reportf( | ||
asmt.Pos(), | ||
"This assignment might mutate a global variable. Lhs can be a pointer to a global variable.", | ||
) | ||
} | ||
c.visitAssignment(asmt) | ||
return true | ||
}) | ||
} | ||
return nil, nil | ||
} | ||
|
||
func isGlobalPointer(expr ast.Expr, pass *analysis.Pass) bool { | ||
t := pass.TypesInfo.TypeOf(expr) | ||
if t == nil { | ||
return false | ||
} | ||
if _, ok := t.Underlying().(*types.Pointer); !ok { | ||
return false | ||
} | ||
sel, ok := expr.(*ast.SelectorExpr) | ||
if !ok { | ||
return false | ||
} | ||
ident, ok := sel.X.(*ast.Ident) | ||
if !ok { | ||
return false | ||
} | ||
o := pass.TypesInfo.ObjectOf(ident) | ||
if o == nil { | ||
return false | ||
} | ||
_, ok = o.(*types.PkgName) | ||
return ok | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package nomutateglobal_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"golang.org/x/tools/go/analysis/analysistest" | ||
|
||
"github.com/seiyab/gost/nomutateglobal" | ||
) | ||
|
||
func TestNoMutateGlobal(t *testing.T) { | ||
testdata := analysistest.TestData() | ||
analysistest.Run(t, testdata, nomutateglobal.Analyzer) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package nomutateglobal | ||
|
||
import ( | ||
"go/ast" | ||
"go/types" | ||
"slices" | ||
|
||
"golang.org/x/tools/go/analysis" | ||
) | ||
|
||
type fieldSet struct { | ||
pass *analysis.Pass | ||
set map[types.Object]fieldSetField | ||
} | ||
|
||
type fieldSetField struct { | ||
// NOTE: use "" as self | ||
children map[string]fieldSetField | ||
} | ||
|
||
func newFieldSet(pass *analysis.Pass) fieldSet { | ||
return fieldSet{ | ||
pass: pass, | ||
set: make(map[types.Object]fieldSetField), | ||
} | ||
} | ||
|
||
func (s *fieldSet) add(sel *ast.SelectorExpr) bool { | ||
o, path, ok := s.retrievePath(sel) | ||
if !ok { | ||
return false | ||
} | ||
otr, ok := s.set[o] | ||
if !ok { | ||
otr = fieldSetField{children: make(map[string]fieldSetField)} | ||
s.set[o] = otr | ||
} | ||
for _, p := range path { | ||
if _, ok := otr.children[p]; !ok { | ||
otr.children[p] = fieldSetField{children: make(map[string]fieldSetField)} | ||
} | ||
otr = otr.children[p] | ||
} | ||
otr.children[""] = fieldSetField{} | ||
return true | ||
} | ||
|
||
func (s *fieldSet) has(sel *ast.SelectorExpr) bool { | ||
o, path, ok := s.retrievePath(sel) | ||
if !ok { | ||
return false | ||
} | ||
otr, ok := s.set[o] | ||
if !ok { | ||
return false | ||
} | ||
for _, p := range path { | ||
otr, ok = otr.children[p] | ||
if !ok { | ||
return false | ||
} | ||
} | ||
_, ok = otr.children[""] | ||
return ok | ||
} | ||
|
||
func (s *fieldSet) retrievePath(sel *ast.SelectorExpr) (types.Object, []string, bool) { | ||
currentSel := sel | ||
var revPath []string | ||
var id *ast.Ident | ||
for id == nil { | ||
revPath = append(revPath, currentSel.Sel.Name) | ||
switch x := currentSel.X.(type) { | ||
case *ast.SelectorExpr: | ||
currentSel = x | ||
case *ast.Ident: | ||
id = x | ||
default: | ||
return nil, nil, false | ||
} | ||
} | ||
o := s.pass.TypesInfo.ObjectOf(id) | ||
if o == nil { | ||
return nil, nil, false | ||
} | ||
slices.Reverse(revPath) | ||
path := revPath | ||
return o, path, true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package nomutateglobal | ||
|
||
import ( | ||
"go/ast" | ||
"go/types" | ||
|
||
"golang.org/x/tools/go/analysis" | ||
) | ||
|
||
type key types.Object | ||
|
||
type globalCollector struct { | ||
pass *analysis.Pass | ||
identCollection map[key]struct{} | ||
fieldCollection fieldSet | ||
} | ||
|
||
func newGlobalCollector(pass *analysis.Pass) globalCollector { | ||
return globalCollector{ | ||
pass: pass, | ||
identCollection: make(map[key]struct{}), | ||
fieldCollection: newFieldSet(pass), | ||
} | ||
} | ||
|
||
func (c *globalCollector) visitAssignment(asmt *ast.AssignStmt) { | ||
for i, rhs := range asmt.Rhs { | ||
if !isGlobalPointer(rhs, c.pass) { | ||
continue | ||
} | ||
lhs := asmt.Lhs[i] | ||
switch lhs := lhs.(type) { | ||
case *ast.Ident: | ||
o := c.pass.TypesInfo.ObjectOf(lhs) | ||
if o == nil { | ||
continue | ||
} | ||
c.identCollection[key(o)] = struct{}{} | ||
case *ast.SelectorExpr: | ||
c.fieldCollection.add(lhs) | ||
} | ||
} | ||
} | ||
|
||
func (c *globalCollector) contains(expr ast.Expr) bool { | ||
switch expr := expr.(type) { | ||
case *ast.Ident: | ||
o := c.pass.TypesInfo.ObjectOf(expr) | ||
if o == nil { | ||
return false | ||
} | ||
_, ok := c.identCollection[key(o)] | ||
return ok | ||
case *ast.SelectorExpr: | ||
return c.fieldCollection.has(expr) | ||
default: | ||
return false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package testdata | ||
|
||
import "net/http" | ||
|
||
type MyClient struct { | ||
Client *http.Client | ||
} | ||
|
||
// https://pkg.go.dev/vuln/GO-2024-2618 | ||
func New(client *http.Client) *MyClient { | ||
c := &MyClient{} | ||
c.Client.Transport = nil | ||
if client != nil { | ||
c.Client = http.DefaultClient | ||
} | ||
c.Client.Timeout = 0 // want ".+" | ||
return c | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package testdata | ||
|
||
import ( | ||
"flag" | ||
"net" | ||
"net/http" | ||
) | ||
|
||
var MyGlobal *struct { | ||
Field string | ||
} | ||
|
||
func _() { | ||
client := http.DefaultClient | ||
client.Transport = nil // want ".+" | ||
|
||
resolver := net.DefaultResolver | ||
resolver.PreferGo = true // want ".+" | ||
|
||
x := MyGlobal | ||
x.Field = "foo" | ||
} | ||
|
||
func _(b bool) { | ||
var f1, f2 *flag.FlagSet | ||
f1.Usage = nil | ||
f2.Usage = nil | ||
if b { | ||
f1 = flag.CommandLine | ||
} else { | ||
f2 = flag.CommandLine | ||
} | ||
f1.Usage = nil // want ".+" | ||
f2.Usage = nil // want ".+" | ||
} |