From e8da0c087aecac7bd88a27dbe626c7252cff6459 Mon Sep 17 00:00:00 2001 From: Conrad Zimmerman Date: Thu, 20 Jun 2024 13:27:12 -0700 Subject: [PATCH] Add pass to simplify multiple return statements --- .../scala/gvc/analyzer/ReturnValidator.scala | 2 - .../scala/gvc/transformer/IRTransformer.scala | 1 + .../transformer/ReturnSimplification.scala | 66 +++++++++++++++++++ .../resources/baseline/framing.baseline.c0 | 6 +- src/test/resources/ir/alloc_expr.ir.c0 | 6 +- src/test/resources/ir/logical.ir.c0 | 6 +- src/test/resources/ir/logical_2.ir.c0 | 6 +- src/test/resources/viper/char.vpr | 6 +- 8 files changed, 87 insertions(+), 12 deletions(-) create mode 100644 src/main/scala/gvc/transformer/ReturnSimplification.scala diff --git a/src/main/scala/gvc/analyzer/ReturnValidator.scala b/src/main/scala/gvc/analyzer/ReturnValidator.scala index dd9f95a..575ce23 100644 --- a/src/main/scala/gvc/analyzer/ReturnValidator.scala +++ b/src/main/scala/gvc/analyzer/ReturnValidator.scala @@ -5,8 +5,6 @@ import scala.annotation.tailrec object ReturnValidator { def validate(program: ResolvedProgram, errors: ErrorSink): Unit = { validateReturn(program, errors) - - // TODO: Do we want to check for early returns? validateTailReturn(program, errors) } diff --git a/src/main/scala/gvc/transformer/IRTransformer.scala b/src/main/scala/gvc/transformer/IRTransformer.scala index 0ce529a..a3fd836 100644 --- a/src/main/scala/gvc/transformer/IRTransformer.scala +++ b/src/main/scala/gvc/transformer/IRTransformer.scala @@ -292,6 +292,7 @@ object IRTransformer { input.declaration.postcondition.map(transformSpec(_, scope)) .orElse(Some(new IR.Imprecise(None))) + ReturnSimplification.transform(method) ReassignmentElimination.transform(method) ParameterAssignmentElimination.transform(method) } diff --git a/src/main/scala/gvc/transformer/ReturnSimplification.scala b/src/main/scala/gvc/transformer/ReturnSimplification.scala new file mode 100644 index 0000000..265c7e7 --- /dev/null +++ b/src/main/scala/gvc/transformer/ReturnSimplification.scala @@ -0,0 +1,66 @@ +package gvc.transformer + +// A pass that removes multiple return statements by transforming +// +// if (cond) return x; +// else return y; +// +// into +// +// RETURN_TYPE result; +// if (cond) result = x; +// else result = y; +// return result; +// +// This allows simpler handling of runtime checks for post-conditions. +// +// It also removes redundant `return` statements in the tail position of void +// methods. +// +// NOTE: This **assumes** that there are NO early returns! This assumption is +// checked by gvc.analyzer.ReturnValidator. + +object ReturnSimplification { + def transform(method: IR.Method): Unit = { + method.returnType match { + case None => removeReturns(method.body) + case Some(t) => method.body.lastOption match { + case Some(_: IR.Return) => + // Non-void method ending in a return statement is fine already + () + case _ => { + // Otherwise, add a new variable, change returns to assignments, and + // add the single return statement to the method + val result = method.addVar(t, "result") + returnToAssignment(method.body, result) + method.body += new IR.Return(Some(result)); + } + } + } + } + + private def removeReturns(block: IR.Block): Unit = { + block.lastOption match { + case Some(ret: IR.Return) => ret.remove() + case Some(cond: IR.If) => { + removeReturns(cond.ifFalse) + removeReturns(cond.ifTrue) + } + case _ => () + } + } + + private def returnToAssignment(block: IR.Block, result: IR.Var): Unit = + block.lastOption match { + case Some(ret: IR.Return) if ret.value.isDefined => { + ret.remove() + block += new IR.Assign(result, ret.value.get) + } + case Some(cond: IR.If) => { + returnToAssignment(cond.ifTrue, result) + returnToAssignment(cond.ifFalse, result) + } + case _ => + throw new TransformerException("Could not find return statement in non-void method") + } +} \ No newline at end of file diff --git a/src/test/resources/baseline/framing.baseline.c0 b/src/test/resources/baseline/framing.baseline.c0 index fbc5b89..6f8deff 100644 --- a/src/test/resources/baseline/framing.baseline.c0 +++ b/src/test/resources/baseline/framing.baseline.c0 @@ -56,6 +56,7 @@ int getValue(struct Outer* outer, struct OwnedFields* _ownedFields) int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields) { + int result = 0; struct OwnedFields* _tempFields = NULL; if (outer != NULL) { @@ -65,12 +66,13 @@ int getValueSafe(struct Outer* outer, struct OwnedFields* _ownedFields) { assertAcc(_ownedFields, outer != NULL ? outer->_id : -1, 0, "Field access runtime check failed for struct Outer.inner"); assertAcc(_ownedFields, outer->inner != NULL ? outer->inner->_id : -1, 0, "Field access runtime check failed for struct Inner.value"); - return outer->inner->value; + result = outer->inner->value; } else { - return 0; + result = 0; } + return result; } int getValueStatic(struct Outer* outer, struct OwnedFields* _ownedFields) diff --git a/src/test/resources/ir/alloc_expr.ir.c0 b/src/test/resources/ir/alloc_expr.ir.c0 index 5f4dafa..b41e941 100644 --- a/src/test/resources/ir/alloc_expr.ir.c0 +++ b/src/test/resources/ir/alloc_expr.ir.c0 @@ -23,6 +23,7 @@ int main() struct _ptr_int* _2 = NULL; struct Test* _3 = NULL; struct Test* _4 = NULL; + int result = 0; _ = alloc(struct Test); _1 = alloc(struct _ptr_int); if (true) @@ -37,12 +38,13 @@ int main() } if (true && _4 == NULL) { - return 1; + result = 1; } else { - return 0; + result = 0; } + return result; } void something(struct Test* value) diff --git a/src/test/resources/ir/logical.ir.c0 b/src/test/resources/ir/logical.ir.c0 index d3fa8c5..492ff8b 100644 --- a/src/test/resources/ir/logical.ir.c0 +++ b/src/test/resources/ir/logical.ir.c0 @@ -7,16 +7,18 @@ int main() { bool a = false; bool b = false; + int result = 0; a = test(1); b = test(2); if (a || b) { - return 0; + result = 0; } else { - return 1; + result = 1; } + return result; } bool test(int value) diff --git a/src/test/resources/ir/logical_2.ir.c0 b/src/test/resources/ir/logical_2.ir.c0 index 99c1027..3f57036 100644 --- a/src/test/resources/ir/logical_2.ir.c0 +++ b/src/test/resources/ir/logical_2.ir.c0 @@ -7,17 +7,19 @@ int main() { bool b = false; bool c = false; + int result = 0; b = test(1); c = test(2); if (b || c) { - return 1; + result = 1; } else { b = !b; - return 1; + result = 1; } + return result; } bool test(int value) diff --git a/src/test/resources/viper/char.vpr b/src/test/resources/viper/char.vpr index 7a4d2a1..626f4b0 100644 --- a/src/test/resources/viper/char.vpr +++ b/src/test/resources/viper/char.vpr @@ -10,10 +10,12 @@ method main() returns ($result: Int) ensures true { var alpha: Bool + var _result$: Int alpha := isAlphabet(48) if (alpha) { - $result := 1 + _result$ := 1 } else { - $result := 0 + _result$ := 0 } + $result := _result$ } \ No newline at end of file