Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sonalmahajan15 committed Jul 3, 2024
1 parent 1e79164 commit 16e396f
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 102 deletions.
5 changes: 3 additions & 2 deletions accumulation/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ func checkErrors(triggers []annotation.FullTrigger, annMap annotation.Map, diagn
// Delete all "always safe" special handlers, since they are not meant to be tested for the no infer case
finalTriggers := make([]annotation.FullTrigger, 0, len(filteredTriggers))
for _, trigger := range filteredTriggers {
if _, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturnForAlwaysSafePath); !ok {
finalTriggers = append(finalTriggers, trigger)
if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe {
continue
}
finalTriggers = append(finalTriggers, trigger)
}

for _, trigger := range finalTriggers {
Expand Down
76 changes: 4 additions & 72 deletions annotation/consume_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,15 +1092,17 @@ func DuplicateReturnConsumer(t *ConsumeTrigger, location token.Position) *Consum
// used for functions with contracts since we need to duplicate the sites for context sensitivity.
type UseAsReturn struct {
*TriggerIfNonNil
IsNamedReturn bool
RetStmt *ast.ReturnStmt
IsNamedReturn bool
IsTrackingAlwaysSafe bool
RetStmt *ast.ReturnStmt
}

// equals returns true if the passed ConsumingAnnotationTrigger is equal to this one
func (u *UseAsReturn) equals(other ConsumingAnnotationTrigger) bool {
if other, ok := other.(*UseAsReturn); ok {
return u.TriggerIfNonNil.equals(other.TriggerIfNonNil) &&
u.IsNamedReturn == other.IsNamedReturn &&
u.IsTrackingAlwaysSafe == other.IsTrackingAlwaysSafe &&
u.RetStmt == other.RetStmt
}
return false
Expand Down Expand Up @@ -1872,76 +1874,6 @@ func (f FldEscapePrestring) String() string {
return sb.String()
}

// UseAsReturnForAlwaysSafePath is when a value flows to a point where it is returned from a rich check effect function, namely,
// error retuning functions and ok-form functions. This consumer is used at the inference stage to determine if the
// rich check effect function is always safe to return a non-nil value.
type UseAsReturnForAlwaysSafePath struct {
*TriggerIfNonNil

IsNamedReturn bool
RetStmt *ast.ReturnStmt
}

// equals returns true if the passed ConsumingAnnotationTrigger is equal to this one
func (u *UseAsReturnForAlwaysSafePath) equals(other ConsumingAnnotationTrigger) bool {
if other, ok := other.(*UseAsReturnForAlwaysSafePath); ok {
return u.TriggerIfNonNil.equals(other.TriggerIfNonNil) &&
u.IsNamedReturn == other.IsNamedReturn &&
u.RetStmt == other.RetStmt
}
return false
}

// Copy returns a deep copy of this ConsumingAnnotationTrigger
func (u *UseAsReturnForAlwaysSafePath) Copy() ConsumingAnnotationTrigger {
copyConsumer := *u
copyConsumer.TriggerIfNonNil = u.TriggerIfNonNil.Copy().(*TriggerIfNonNil)
return &copyConsumer
}

// Prestring returns this UseAsNonErrorRetDependentOnErrorRetNilability as a Prestring
func (u *UseAsReturnForAlwaysSafePath) Prestring() Prestring {
retAnn := u.Ann.(*RetAnnotationKey)
return UseAsReturnForAlwaysSafePathPrestring{
retAnn.FuncDecl.Name(),
retAnn.RetNum,
retAnn.FuncDecl.Type().(*types.Signature).Results().At(retAnn.RetNum).Name(),
retAnn.FuncDecl.Type().(*types.Signature).Results().Len() - 1,
u.IsNamedReturn,
u.assignmentFlow.String(),
}
}

// UseAsReturnForAlwaysSafePathPrestring is a Prestring storing the needed information to compactly encode a UseAsReturnForAlwaysSafePath
type UseAsReturnForAlwaysSafePathPrestring struct {
FuncName string
RetNum int
RetName string
ErrRetNum int
IsNamedReturn bool
AssignmentStr string
}

func (u UseAsReturnForAlwaysSafePathPrestring) String() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("returned from `%s()`", u.FuncName))
if u.IsNamedReturn {
sb.WriteString(fmt.Sprintf(" via named return `%s`", u.RetName))
} else {
sb.WriteString(fmt.Sprintf(" in position %d", u.RetNum))
}
sb.WriteString(u.AssignmentStr)
return sb.String()
}

// overriding position value to point to the raw return statement, which is the source of the potential error
func (u *UseAsReturnForAlwaysSafePath) customPos() (token.Pos, bool) {
if u.IsNamedReturn {
return u.RetStmt.Pos(), true
}
return 0, false
}

// UseAsNonErrorRetDependentOnErrorRetNilability is when a value flows to a point where it is returned from an error returning function
type UseAsNonErrorRetDependentOnErrorRetNilability struct {
*TriggerIfNonNil
Expand Down
1 change: 0 additions & 1 deletion annotation/consume_trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ var initStructsConsumingAnnotationTrigger = []any{
&UseAsErrorRetWithNilabilityUnknown{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
&ArgPassDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&UseAsReturnDeep{TriggerIfDeepNonNil: &TriggerIfDeepNonNil{Ann: newMockKey()}},
&UseAsReturnForAlwaysSafePath{TriggerIfNonNil: &TriggerIfNonNil{Ann: newMockKey()}},
}

// ConsumingAnnotationTriggerEqualsTestSuite tests for the `equals` method of all the structs that implement
Expand Down
5 changes: 3 additions & 2 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,14 @@ func backpropAcrossReturn(rootNode *RootAssertionNode, node *ast.ReturnStmt) err
triggerAlwaysSafe := annotation.FullTrigger{
Producer: trigger.Producer,
Consumer: &annotation.ConsumeTrigger{
Annotation: &annotation.UseAsReturnForAlwaysSafePath{
Annotation: &annotation.UseAsReturn{
TriggerIfNonNil: &annotation.TriggerIfNonNil{
Ann: annotation.RetKeyFromRetNum(
rootNode.ObjectOf(rootNode.FuncNameIdent()).(*types.Func),
i,
)},
RetStmt: node,
RetStmt: node,
IsTrackingAlwaysSafe: true,
},
Expr: trigger.Consumer.Expr,
Guards: trigger.Consumer.Guards,
Expand Down
7 changes: 4 additions & 3 deletions assertion/function/assertiontree/backprop_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,16 @@ func createReturnConsumersForAlwaysSafe(rootNode *RootAssertionNode, nonErrResul
}

rootNode.AddConsumption(&annotation.ConsumeTrigger{
Annotation: &annotation.UseAsReturnForAlwaysSafePath{
Annotation: &annotation.UseAsReturn{
TriggerIfNonNil: &annotation.TriggerIfNonNil{
Ann: &annotation.RetAnnotationKey{
FuncDecl: rootNode.FuncObj(),
RetNum: i,
},
},
IsNamedReturn: isNamedReturn,
RetStmt: retStmt},
IsNamedReturn: isNamedReturn,
IsTrackingAlwaysSafe: true,
RetStmt: retStmt},
Expr: nonErrResults[i],
Guards: util.NoGuards(),
})
Expand Down
4 changes: 3 additions & 1 deletion assertion/function/assertiontree/root_assertion_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ func (r *RootAssertionNode) AddConsumption(consumer *annotation.ConsumeTrigger)
path, producers := r.ParseExprAsProducer(consumer.Expr, false)
if path == nil { // expr is not trackable
if producers == nil {
if _, ok := consumer.Annotation.(*annotation.UseAsReturnForAlwaysSafePath); ok {
// Here we can infer that the expression is non-nil by definition. Instead of ignoring creation of a trigger,
// particularly for always safe tracking, we create a trigger with ProduceTriggerNever.
if c, ok := consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe {
r.AddNewTriggers(annotation.FullTrigger{
Producer: &annotation.ProduceTrigger{
Annotation: &annotation.ProduceTriggerNever{},
Expand Down
44 changes: 23 additions & 21 deletions inference/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (e *Engine) mapGuardMissingAndReturnToFuncSite(triggers []annotation.FullTr
}

for i, trigger := range triggers {
if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturnForAlwaysSafePath); ok {
if c, ok := trigger.Consumer.Annotation.(*annotation.UseAsReturn); ok && c.IsTrackingAlwaysSafe {
site := e.primitive.site(c.UnderlyingSite(), c.Kind() == annotation.DeepConditional)
mapSiteReturn[site] = append(mapSiteReturn[site], i)
}
Expand All @@ -180,22 +180,11 @@ func (e *Engine) mapGuardMissingAndReturnToFuncSite(triggers []annotation.FullTr
// observeImplication. Before all assertions are sorted and handled thus, the annotations read for
// the package are iterated over and observed via calls to observeSiteExplanation as a <Val>BecauseAnnotation.
func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) {
// Separate out triggers with UseAsNonErrorRetDependentOnErrorRetNilability consumer from other triggers.
// This is needed since whether UseAsNonErrorRetDependentOnErrorRetNilability triggers should be fired
// is dependent on their corresponding UseAsErrorRetWithNilabilityUnknown triggers. By this separation,
// we can process all other triggers, including UseAsErrorRetWithNilabilityUnknown, first, and once
// their nilability status is known, then filter out the unnecessary UseAsNonErrorRetDependentOnErrorRetNilability
// triggers, and run the pkg inference process again only for the remainder triggers.
// Steps 1--3 below depict this approach in more detail.
var (
nonErrRetTriggers []annotation.FullTrigger
// In most cases all triggers will be stored in otherTriggers, so we set a proper capacity.
otherTriggers = make([]annotation.FullTrigger, 0, len(pkgFullTriggers))
)

// Analyze "always safe" paths for rich check effect functions, namely error returning functions and ok-returning functions.
// The process is to find all guard missing triggers reaching function return sites, and then check if all the return triggers
// to the function site are non-nil. If so, we can safely delete all the guard-missing triggers for this function site.
// As Step 1, we do a pre-analysis of "guard missing" triggers to verify their dereferences are always safe,
// and hence can be safely deleted. Specifically, this analyis of "always safe" paths is focussed on the rich check
// effect functions, namely error returning functions and ok-returning functions. The process is to find all
// guard missing triggers reaching a function return site, and then check if all the return triggers
// to that function site are non-nil. If so, we can safely delete all the guard-missing triggers for this function site.
triggersToBeDeleted := make(map[int]bool)
mapSiteGuardMissing, mapSiteReturn := e.mapGuardMissingAndReturnToFuncSite(pkgFullTriggers)
for site, guardMissingIndices := range mapSiteGuardMissing {
Expand Down Expand Up @@ -227,6 +216,7 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) {
}
}

// Filter out the triggers that are to be deleted.
var filteredPkgFullTriggers []annotation.FullTrigger
for i, t := range pkgFullTriggers {
if triggersToBeDeleted[i] {
Expand All @@ -236,6 +226,19 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) {
}
pkgFullTriggers = filteredPkgFullTriggers

// Separate out triggers with UseAsNonErrorRetDependentOnErrorRetNilability consumer from other triggers.
// This is needed since whether UseAsNonErrorRetDependentOnErrorRetNilability triggers should be fired
// is dependent on their corresponding UseAsErrorRetWithNilabilityUnknown triggers. By this separation,
// we can process all other triggers, including UseAsErrorRetWithNilabilityUnknown, first, and once
// their nilability status is known, then filter out the unnecessary UseAsNonErrorRetDependentOnErrorRetNilability
// triggers, and run the pkg inference process again only for the remainder triggers.
// Steps 2--4 below depict this approach in more detail.
var (
nonErrRetTriggers []annotation.FullTrigger
// In most cases all triggers will be stored in otherTriggers, so we set a proper capacity.
otherTriggers = make([]annotation.FullTrigger, 0, len(pkgFullTriggers))
)

for _, t := range pkgFullTriggers {
if _, ok := t.Consumer.Annotation.(*annotation.UseAsNonErrorRetDependentOnErrorRetNilability); ok {
nonErrRetTriggers = append(nonErrRetTriggers, t)
Expand All @@ -244,10 +247,10 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) {
}
}

// Step 1: build the inference map based on `otherTriggers` and incorporate those assertions into the `inferredAnnotationMap`
// Step 2: build the inference map based on `otherTriggers` and incorporate those assertions into the `inferredAnnotationMap`
e.buildPkgInferenceMap(otherTriggers)

// Step 2: run error return handling procedure to filter out redundant triggers based on the error contract, and
// Step 3: run error return handling procedure to filter out redundant triggers based on the error contract, and
// keep only those UseAsNonErrorRetDependentOnErrorRetNilability triggers that are not deleted.
// Call FilterTriggersForErrorReturn to filter triggers for error return handling -- inter-procedural and full-inference mode
_, delTriggers := assertiontree.FilterTriggersForErrorReturn(
Expand Down Expand Up @@ -298,7 +301,7 @@ func (e *Engine) ObservePackage(pkgFullTriggers []annotation.FullTrigger) {
}
}

// Step 3: run the inference building process for only the remaining UseAsNonErrorRetDependentOnErrorRetNilability triggers, and collect assertions
// Step 4: run the inference building process for only the remaining UseAsNonErrorRetDependentOnErrorRetNilability triggers, and collect assertions
e.buildPkgInferenceMap(filteredTriggers)
}

Expand Down Expand Up @@ -640,5 +643,4 @@ func GobRegister() {
gob.RegisterName(nextStr(), annotation.RecvPassPrestring{})
gob.RegisterName(nextStr(), annotation.MethodRecvDeepPrestring{})
gob.RegisterName(nextStr(), annotation.FldReturnPrestring{})
gob.RegisterName(nextStr(), annotation.UseAsReturnForAlwaysSafePathPrestring{})
}

0 comments on commit 16e396f

Please sign in to comment.