Skip to content

Commit

Permalink
Add support for Swap operation for array elements
Browse files Browse the repository at this point in the history
  • Loading branch information
jad-hamza committed Mar 25, 2021
1 parent 43aa5c7 commit cab7d37
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,37 @@ trait AntiAliasing
case ret @ Return(retExpr) =>
Return(Tuple(transform(retExpr, env) +: freshLocals.map(_.setPos(ret))).setPos(ret)).setPos(ret)

case swap @ Swap(array1, index1, array2, index2) =>
val base = array1.getType.asInstanceOf[ArrayType].base
val temp = ValDef.fresh("temp", base).setPos(swap)
val ra1 = exprOps.replaceFromSymbols(env.rewritings, array1)
val targets1 = getTargets(ra1)
val ra2 = exprOps.replaceFromSymbols(env.rewritings, array2)
val targets2 = getTargets(ra2)

if (targets1.exists(target => !env.bindings.contains(target.receiver.toVal)))
throw MalformedStainlessCode(swap, "Unsupported swap (first array)")

if (targets2.exists(target => !env.bindings.contains(target.receiver.toVal)))
throw MalformedStainlessCode(swap, "Unsupported swap (second array)")

val updates1 =
targets1.toSeq map { target =>
val applied = updatedTarget(target + ArrayAccessor(index1), ArraySelect(array2, index2).setPos(swap))
transform(Assignment(target.receiver, applied).setPos(swap), env)
}
val updates2 =
targets2.toSeq map { target =>
val applied = updatedTarget(target + ArrayAccessor(index2), temp.toVariable)
transform(Assignment(target.receiver, applied).setPos(swap), env)
}
val updates = updates1 ++ updates2
if (updates.isEmpty) UnitLiteral().setPos(swap)
else
Let(temp, transform(ArraySelect(array1, index1).setPos(swap), env),
Block(updates.init, updates.last).setPos(swap)
).setPos(swap)

case l @ Let(vd, e, b) if isMutableType(vd.tpe) =>
// see https://github.com/epfl-lara/stainless/pull/920 for discussion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,27 @@ trait EffectsAnalyzer extends oo.CachingPhase {
case ADTFieldAccessor(fid) +: rest =>
getTargets(args(symbols.getConstructor(id).fields.indexWhere(_.id == fid)), rest)
case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in ADT: ${expr.asString}")
throw MalformedStainlessCode(expr,
s"Couldn't compute effect targets in ADT: ${expr.asString}\n" +
s"Path: ${path.map(_.asString)}")
}

case ClassConstructor(ct, args) => path match {
case ClassFieldAccessor(fid) +: rest =>
getTargets(args(ct.tcd.fields.indexWhere(_.id == fid)), rest)
case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in class constructor: ${expr.asString}")
throw MalformedStainlessCode(expr,
s"Couldn't compute effect targets in class constructor: ${expr.asString}\n" +
s"Path: ${path.map(_.asString)}")
}

case FiniteArray(args, _) => path match {
case ArrayAccessor(Int32Literal(i)) +: rest =>
getTargets(args(i), rest)
case _ =>
throw MalformedStainlessCode(expr,
s"Couldn't compute effect targets in array: ${expr.asString}\n" +
s"Path: ${path.map(_.asString)}")
}

case Assert(_, _, e) => getTargets(e, path)
Expand Down Expand Up @@ -395,7 +408,10 @@ trait EffectsAnalyzer extends oo.CachingPhase {
}

case _ =>
throw MalformedStainlessCode(expr, s"Couldn't compute effect targets in (${expr.getClass}): ${expr.asString}")
throw MalformedStainlessCode(expr,
s"Couldn't compute effect targets in (${expr.getClass}): ${expr.asString}\n" +
s"Path: ${path.map(_.asString)}"
)
}

def getTargets(expr: Expr)(implicit symbols: Symbols): Set[Target] = {
Expand Down Expand Up @@ -456,6 +472,11 @@ trait EffectsAnalyzer extends oo.CachingPhase {
guard.toSeq.flatMap(rec(_, newEnv)).toSet ++ rec(rhs, newEnv)
}

case Swap(array1, index1, array2, index2) =>
rec(array1, env) ++ rec(index1, env) ++ rec(array2, env) ++ rec(index2, env) ++
effect(array1, env).map(_ + ArrayAccessor(index1)) ++
effect(array2, env).map(_ + ArrayAccessor(index2))

case ArrayUpdate(o, idx, v) =>
rec(o, env) ++ rec(idx, env) ++ rec(v, env) ++
effect(o, env).map(_ + ArrayAccessor(idx))
Expand Down
17 changes: 17 additions & 0 deletions core/src/main/scala/stainless/extraction/imperative/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ trait Trees extends oo.Trees with Definitions { self =>
if (expr.isTyped) NothingType() else Untyped
}

/** Swap indices from two (not necessarily distinct) arrays */
sealed case class Swap(array1: Expr, index1: Expr, array2: Expr, index2: Expr) extends Expr with CachingTyped {
override protected def computeType(implicit s: Symbols): Type =
(array1.getType, array2.getType) match {
case (ArrayType(base1), ArrayType(base2)) if base1 == base2 =>
checkParamTypes(Seq(index1, index2), Seq(Int32Type(), Int32Type()), UnitType())
case _ =>
Untyped
}
}

/** $encodingof `{ expr1; expr2; ...; exprn; last }` */
case class Block(exprs: Seq[Expr], last: Expr) extends Expr with CachingTyped {
protected def computeType(implicit s: Symbols): Type = if (exprs.forall(_.isTyped)) last.getType else Untyped
Expand Down Expand Up @@ -205,6 +216,9 @@ trait Printer extends oo.Printer {
case Return(e) =>
p"return $e"

case Swap(array1, index1, array2, index2) =>
p"swap($array1, $index1, $array2, $index2)"

case LetVar(vd, value, expr) =>
p"""|var $vd = $value
|$expr"""
Expand Down Expand Up @@ -327,6 +341,9 @@ trait TreeDeconstructor extends oo.TreeDeconstructor {
case s.Return(e) =>
(Seq(), Seq(), Seq(e), Seq(), Seq(), (_, _, es, _, _) => t.Return(es(0)))

case s.Swap(array1, index1, array2, index2) =>
(Seq(), Seq(), Seq(array1, index1, array2, index2), Seq(), Seq(), (_, _, es, _, _) => t.Swap(es(0), es(1), es(2), es(3)))

case s.MutableMapWithDefault(from, to, default) =>
(Seq(), Seq(), Seq(default), Seq(from, to), Seq(), (_, _, es, tps, _) => t.MutableMapWithDefault(tps(0), tps(1), es(0)))

Expand Down
30 changes: 30 additions & 0 deletions frontends/benchmarks/imperative/valid/Swap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import stainless.lang.swap

object Swap {
def test(a1: Array[BigInt]): Unit = {
require(a1.length > 10 && a1(2) == 200 && a1(4) == 500)

swap(a1, 2, a1, 2)
assert(a1(2) == 200)
swap(a1, 2, a1, 4)
assert(a1(2) == 500)
assert(a1(4) == 200)
swap(a1, 2, a1, 4)
assert(a1(2) == 200)
assert(a1(4) == 500)
}

def test2: Unit = {
val a2 = Array(4, 8, 15, 16, 23, 42)

swap(a2, 1, a2, 1)
assert(a2(1) == 8)
swap(a2, 0, a2, 5)
assert(a2(0) == 42)
assert(a2(1) == 8)
assert(a2(2) == 15)
assert(a2(3) == 16)
assert(a2(4) == 23)
assert(a2(5) == 4)
}
}
25 changes: 25 additions & 0 deletions frontends/benchmarks/imperative/valid/Swap2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import stainless.lang.swap

object Swap2 {

case class Box(var x: BigInt)

def test: Unit = {
val a = Array[Box](Box(4), Box(8), Box(15), Box(16), Box(23), Box(42))
swap(a, 2, a, 3)
a(2).x *= 2
a(3).x += 2
assert(a(0).x == 4)
assert(a(1).x == 8)
assert(a(2).x == 32)
assert(a(3).x == 17)
assert(a(4).x == 23)
assert(a(5).x == 42)
swap(a, 2, a, 3)
assert(a(2).x == 17)
assert(a(3).x == 32)
swap(a, 4, a, 3)
assert(a(4).x == 32)
assert(a(3).x == 23)
}
}
11 changes: 11 additions & 0 deletions frontends/library/stainless/lang/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,15 @@ package object lang {
}
}

@ignore @library
def swap[@mutable T](a1: Array[T], i1: Int, a2: Array[T], i2: Int): Unit = {
require(
0 <= i1 && i1 < a1.length &&
0 <= i2 && i2 < a2.length
)
val t = a1(i1)
a1(i1) = a2(i2)
a2(i2) = t
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,16 @@ trait ASTExtractors {
}
}

object ExSwapExpression {
def unapply(tree: Apply) : Option[(Tree, Tree, Tree, Tree)] = tree match {
case a @ Apply(
TypeApply(ExSymbol("stainless", "lang", "swap"), _),
array1 :: index1 :: array2 :: index2 :: Nil) =>
Some((array1, index1, array2, index2))
case _ => None
}
}

object ExLambdaExpression {
def unapply(tree: Function) : Option[(Seq[ValDef], Tree)] = tree match {
case Function(vds, body) => Some((vds, body))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,9 @@ trait CodeExtraction extends ASTExtractors {
case _ => outOfSubsetError(tr, "Unexpected choose definition")
}

case swap @ ExSwapExpression(array1, index1, array2, index2) =>
xt.Swap(extractTree(array1), extractTree(index1), extractTree(array2), extractTree(index2))

case l @ ExLambdaExpression(args, body) =>
val vds = args map(vd => xt.ValDef(
FreshIdentifier(vd.symbol.name.toString),
Expand Down

0 comments on commit cab7d37

Please sign in to comment.