Skip to content

Commit

Permalink
Add pass to simplify multiple return statements
Browse files Browse the repository at this point in the history
  • Loading branch information
conradz committed Jun 20, 2024
1 parent dd30dbf commit e8da0c0
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 12 deletions.
2 changes: 0 additions & 2 deletions src/main/scala/gvc/analyzer/ReturnValidator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import scala.annotation.tailrec
object ReturnValidator {
def validate(program: ResolvedProgram, errors: ErrorSink): Unit = {
validateReturn(program, errors)

// TODO: Do we want to check for early returns?
validateTailReturn(program, errors)
}

Expand Down
1 change: 1 addition & 0 deletions src/main/scala/gvc/transformer/IRTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ object IRTransformer {
input.declaration.postcondition.map(transformSpec(_, scope))
.orElse(Some(new IR.Imprecise(None)))

ReturnSimplification.transform(method)
ReassignmentElimination.transform(method)
ParameterAssignmentElimination.transform(method)
}
Expand Down
66 changes: 66 additions & 0 deletions src/main/scala/gvc/transformer/ReturnSimplification.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package gvc.transformer

// A pass that removes multiple return statements by transforming
//
// if (cond) return x;
// else return y;
//
// into
//
// RETURN_TYPE result;
// if (cond) result = x;
// else result = y;
// return result;
//
// This allows simpler handling of runtime checks for post-conditions.
//
// It also removes redundant `return` statements in the tail position of void
// methods.
//
// NOTE: This **assumes** that there are NO early returns! This assumption is
// checked by gvc.analyzer.ReturnValidator.

object ReturnSimplification {
def transform(method: IR.Method): Unit = {
method.returnType match {
case None => removeReturns(method.body)
case Some(t) => method.body.lastOption match {
case Some(_: IR.Return) =>
// Non-void method ending in a return statement is fine already
()
case _ => {
// Otherwise, add a new variable, change returns to assignments, and
// add the single return statement to the method
val result = method.addVar(t, "result")
returnToAssignment(method.body, result)
method.body += new IR.Return(Some(result));
}
}
}
}

private def removeReturns(block: IR.Block): Unit = {
block.lastOption match {
case Some(ret: IR.Return) => ret.remove()
case Some(cond: IR.If) => {
removeReturns(cond.ifFalse)
removeReturns(cond.ifTrue)
}
case _ => ()
}
}

private def returnToAssignment(block: IR.Block, result: IR.Var): Unit =
block.lastOption match {
case Some(ret: IR.Return) if ret.value.isDefined => {
ret.remove()
block += new IR.Assign(result, ret.value.get)
}
case Some(cond: IR.If) => {
returnToAssignment(cond.ifTrue, result)
returnToAssignment(cond.ifFalse, result)
}
case _ =>
throw new TransformerException("Could not find return statement in non-void method")
}
}
6 changes: 4 additions & 2 deletions src/test/resources/baseline/framing.baseline.c0
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ int getValue(struct Outer* outer, struct OwnedFields* _ownedFields)

int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields)
{
int result = 0;
struct OwnedFields* _tempFields = NULL;
if (outer != NULL)
{
Expand All @@ -65,12 +66,13 @@ int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields)
{
assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner");
assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value");
return outer->inner->value;
result = outer->inner->value;
}
else
{
return 0;
result = 0;
}
return result;
}

int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields)
Expand Down
6 changes: 4 additions & 2 deletions src/test/resources/ir/alloc_expr.ir.c0
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ int main()
struct _ptr_int* _2 = NULL;
struct Test* _3 = NULL;
struct Test* _4 = NULL;
int result = 0;
_ = alloc(struct Test);
_1 = alloc(struct _ptr_int);
if (true)
Expand All @@ -37,12 +38,13 @@ int main()
}
if (true && _4 == NULL)
{
return 1;
result = 1;
}
else
{
return 0;
result = 0;
}
return result;
}

void something(struct Test* value)
Expand Down
6 changes: 4 additions & 2 deletions src/test/resources/ir/logical.ir.c0
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ int main()
{
bool a = false;
bool b = false;
int result = 0;
a = test(1);
b = test(2);
if (a || b)
{
return 0;
result = 0;
}
else
{
return 1;
result = 1;
}
return result;
}

bool test(int value)
Expand Down
6 changes: 4 additions & 2 deletions src/test/resources/ir/logical_2.ir.c0
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ int main()
{
bool b = false;
bool c = false;
int result = 0;
b = test(1);
c = test(2);
if (b || c)
{
return 1;
result = 1;
}
else
{
b = !b;
return 1;
result = 1;
}
return result;
}

bool test(int value)
Expand Down
6 changes: 4 additions & 2 deletions src/test/resources/viper/char.vpr
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ method main() returns ($result: Int)
ensures true
{
var alpha: Bool
var _result$: Int
alpha := isAlphabet(48)
if (alpha) {
$result := 1
_result$ := 1
} else {
$result := 0
_result$ := 0
}
$result := _result$
}

0 comments on commit e8da0c0

Please sign in to comment.