Vlad Ureche edited this page Jun 15, 2014 · 13 revisions

Multi-stage programming differs from partial evaluation in the context in which the program is transformed. In multi-stage programming (called staging from here on), the context is given by executing the program itself, including its side effects, and extracting a lifted expression that can then be transformed into code. Partial evaluation, on the other hand, is performed by an external tool, which can only evaluate a predefined set of side-effect-free operations.

Therefore, staging brings new optimization opportunities compared to partial evaluation: it can burn through thick layers of abstraction and output operations specialized for the task at hand. Still, executing side-effecting code during staging has an undesirable consequence: the code generated for a staged expression may no longer be correct later, if the side-effecting operations would return different results when executed again.
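To make the side-effect problem concrete, here is a minimal sketch that does not use the staging plugin. The tiny IR (Exp, Con, Arg, Times) and the methods stagedScale, plainScale and demo are hypothetical names chosen for illustration. A mutable variable read during staging is frozen into the IR as a constant, so changing it afterwards does not affect the already generated expression, while the unstaged version sees the new value.

```scala
object StagingPitfall {
  // Hypothetical minimal IR; not the plugin runtime's types.
  sealed trait Exp
  final case class Con(v: Double) extends Exp     // staging-time constant
  final case class Arg(name: String) extends Exp  // next-stage argument
  final case class Times(a: Exp, b: Exp) extends Exp

  // Interpret a generated expression, given a value for the argument.
  def eval(e: Exp, arg: Double): Double = e match {
    case Con(v)      => v
    case Arg(_)      => arg
    case Times(a, b) => eval(a, arg) * eval(b, arg)
  }

  // Mutable state: reading it is a side-effecting operation.
  var scale: Double = 2.0

  // Staged: `scale` is read once, at staging time, and frozen into the IR.
  def stagedScale(x: Exp): Exp = Times(x, Con(scale))

  // Unstaged: `scale` is read every time the method runs.
  def plainScale(x: Double): Double = x * scale

  // Stage while scale == 2.0, then change scale and run both versions on 5.0.
  def demo(): (Double, Double) = {
    scale = 2.0
    val generated = stagedScale(Arg("x"))
    scale = 3.0
    (eval(generated, 5.0), plainScale(5.0))
  }
}
```

Here demo() returns (10.0, 15.0): the staged code still multiplies by the stale value 2.0.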

Let us take an example. Note: although the Unified Data Representation paper describes the annotation as @lifted, the implementation actually uses @staged instead.

$ cd ~/Workspace/staging-plugin/examples/
$ cat pow.scala 
package stagium
package examples
package pow

import scala.reflect.runtime.universe._

// A test for the power function
object Test {
  def main(args: Array[String]): Unit = {

    // this is a method with staged arguments
    def pow(e: Double @staged, p: Int): Double @staged =
      if ( p == 0 ) 1.0
      else if ( p % 2 == 1 ) e * pow(e, p - 1)
      else { // p % 2 == 0
        val x = pow(e, p/2)
        x * x
      }

    // and this is what we stage
    println("execute: " + execute(pow(3, 5)) + "\n")
    val fun1 = function1[Double, Double](e => pow(e, 5))
    println("fun1(3): " + fun1(3) + "\n")
    val fun2 = function2[Double, Double, Double]((e1, e2) => pow(e1, 5) * pow(e2, 5))
    println("fun2(3, 1): " + fun2(3, 1) + "\n")
  }
}
...

In this example, it would be very attractive to partially evaluate pow for the exponent p, which would allow faster execution when the exponent is fixed ahead of time. This is where the @staged annotation comes into play: it marks next-stage values, namely the values that are not treated as constants during partial evaluation. (Note that in staging, everything except the @staged values is considered a staging-time constant.)
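What "partially evaluating pow for p" buys can be illustrated in plain Scala, without the plugin. In this sketch, the hypothetical specialize method performs all the tests on p once, when the exponent is known, and returns a closure that no longer inspects p. Closures only stand in for generated code here; they still carry call overhead, which real staging removes by emitting straight-line multiplications.

```scala
object PowSpecialization {
  // Ordinary pow: the tests on p are repeated on every call.
  def pow(e: Double, p: Int): Double =
    if (p == 0) 1.0
    else if (p % 2 == 1) e * pow(e, p - 1)
    else { val x = pow(e, p / 2); x * x }

  // Hypothetical specializer: p is a "staging-time" constant, so every test
  // on p happens here, once; the returned closure never looks at p again.
  def specialize(p: Int): Double => Double =
    if (p == 0) ((_: Double) => 1.0)
    else if (p % 2 == 1) {
      val rest = specialize(p - 1)
      (e: Double) => e * rest(e)
    } else {
      val half = specialize(p / 2)
      (e: Double) => { val x = half(e); x * x }
    }
}
```

For example, specialize(5) can be computed once and then applied to many bases; specialize(5)(3.0) gives 243.0, the same as pow(3.0, 5).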

These @staged-annotated values, with the help of the staging plugin runtime and transformation, gather the expressions they are involved in into an intermediate representation. From this intermediate representation, new, optimized code can be generated, which can execute significantly faster. In Scala, this has traditionally been done using the scala-virtualized compiler together with the lightweight modular staging framework. This plugin explores a different approach to compiler-level staging, based on the data representation transformation.

Compiling and running the example, first with the plugin in passive mode and then with staging active, yields:

$ st-scalac pow.scala -P:stagium:passive
$ st-scala stagium.examples.pow.Test
execute: 243.0

fun1(3): 243.0

fun2(3, 1): 243.0

$ st-scalac pow.scala
$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
{
  val x0: Double = 3.0 * 3.0
  val x1: Double = x0 * x0
  val x2: Double = 3.0 * x1
  x2: Double
} // end of code block of 3 instructions
*********************************
execute: null

Need to compile and run:
*********************************
{
  val x6: Double => Double =
    (a1: Double) =>    {
      val x3: Double = a1 * a1
      val x4: Double = x3 * x3
      val x5: Double = a1 * x4
      x5: Double
    } // end of code block of 3 instructions

  x6: Double => Double
} // end of function x6: Double => Double
*********************************
<function1 called>
fun1(3): 0.0

Need to compile and run:
*********************************
{
  val x14: (Double, Double) => Double =
    (a1: Double, a2: Double) =>    {
      val x10: Double = a2 * a2
      val x7: Double = a1 * a1
      val x11: Double = x10 * x10
      val x8: Double = x7 * x7
      val x12: Double = a2 * x11
      val x9: Double = a1 * x8
      val x13: Double = x9 * x12
      x13: Double
    } // end of code block of 7 instructions

  x14: (Double, Double) => Double
} // end of function x14: (Double, Double) => Double
*********************************
<function2 called>
fun2(3, 1): 0.0

In the first run, the code was executed immediately, since the staging plugin was set to passive mode. In the second run, the intermediate representation was extracted and new code was generated for the partially evaluated pow function.

NOTE: At the time this wiki page was written, the mechanism to compile and execute the code generated at staging time was not available, so execute returned null and the functions returned 0.0 instead of the actual values. We plan to add this functionality.

Staging internals

How exactly was the code for these functions generated? We use the same mechanism as the lightweight modular staging framework: method calls on staged values are redirected to infix methods, which in turn create the intermediate representation.

To see how this works, it helps to look at the entire pow.scala example:

$ cat pow.scala 
package stagium
package examples
package pow

import scala.reflect.runtime.universe._

// A test for the power function
object Test {
  def main(args: Array[String]): Unit = {

    // this is a method with staged arguments
    def pow(e: Double @staged, p: Int): Double @staged =
      if ( p == 0 ) 1.0
      else if ( p % 2 == 1 ) e * pow(e, p - 1)
      else { // p % 2 == 0
        val x = pow(e, p/2)
        x * x
      }

    // and this is what we stage
    println("execute: " + execute(pow(3, 5)) + "\n")
    val fun1 = function1[Double, Double](e => pow(e, 5))
    println("fun1(3): " + fun1(3) + "\n")
    val fun2 = function2[Double, Double, Double]((e1, e2) => pow(e1, 5) * pow(e2, 5))
    println("fun2(3, 1): " + fun2(3, 1) + "\n")
  }
}


// This is the support object for staging
object __staged {
  case class DoubleTimes(t1: Exp[Double], t2: Exp[Double]) extends Def[Double] {
    override def toString = t1 + " * " + t2
  }
  def infix_*(r: Exp[Double], oth: Exp[Double]): Exp[Double] =
    (r, oth) match {
      case (Con(1), _) => oth
      case (_, Con(1)) => r
      case _ => DoubleTimes(r, oth)
    }
}

As you can see, the example also defines infix_*, the method to which multiplication (*) is redirected when the receiver is staged. The body of infix_* contains a simple optimization: multiplying a number by 1 leaves it unchanged. This is handled by the first two cases of the match expression: Con(1) represents a constant of value 1, so when either operand is this unit constant, the other operand is returned.

The final translation of pow, when the staging plugin is active, is:

def pow(e: Exp[Double], p: Int): Exp[Double] =
  if ( p == 0 ) Con[Double](1.0)
  else if ( p % 2 == 1 ) infix_*(e, pow(e, p - 1))
  else { // p % 2 == 0
    val x = pow(e, p/2)
    infix_*(x, x)
  }
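The translated pow above can be mimicked without the plugin to see what it computes. In the following sketch, Exp, Con, Arg and Times are a hypothetical tree-shaped IR standing in for the plugin runtime's types, and infix_* keeps the same unit-elimination cases as the support object. Calling pow on an argument builds an expression instead of a number, and evaluating that expression with e = 3 gives the same result as the unstaged pow.

```scala
object PowTreeSketch {
  // Hypothetical tree IR standing in for the plugin runtime's Exp types.
  sealed trait Exp
  final case class Con(v: Double) extends Exp     // staging-time constant
  final case class Arg(name: String) extends Exp  // next-stage argument
  final case class Times(a: Exp, b: Exp) extends Exp

  // Multiplication on staged values builds IR instead of computing a number.
  def infix_*(r: Exp, oth: Exp): Exp = (r, oth) match {
    case (Con(1.0), _) => oth   // 1 * x == x
    case (_, Con(1.0)) => r     // x * 1 == x
    case _             => Times(r, oth)
  }

  // Same shape as the translated pow: p is an ordinary staging-time Int.
  def pow(e: Exp, p: Int): Exp =
    if (p == 0) Con(1.0)
    else if (p % 2 == 1) infix_*(e, pow(e, p - 1))
    else { val x = pow(e, p / 2); infix_*(x, x) }

  // Render an expression as text.
  def show(e: Exp): String = e match {
    case Con(v)      => v.toString
    case Arg(n)      => n
    case Times(a, b) => s"(${show(a)} * ${show(b)})"
  }

  // Interpret an expression for a given argument value.
  def eval(e: Exp, arg: Double): Double = e match {
    case Con(v)      => v
    case Arg(_)      => arg
    case Times(a, b) => eval(a, arg) * eval(b, arg)
  }
}
```

Staging pow(Arg("e"), 5) yields the tree (e * ((e * e) * (e * e))). No "* 1.0" appears because the unit case removed it. The subexpression (e * e) is repeated in the tree; in the plugin, symbols for definitions avoid such duplication.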

The type Exp marks next-stage values, such as Constants, Arguments or Symbols. These are all defined in the value class plugin runtime. Symbols correspond to definitions, thus allowing common subexpression elimination and other standard optimizations.
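Why symbols allow common subexpression elimination can be shown with a small sketch. Every definition is recorded in a table, and creating an identical definition again returns the existing symbol. Builder, toAtom and the table are hypothetical stand-ins, not the plugin runtime's API. DoubleTimes reuses the name from the support object above.

```scala
import scala.collection.mutable

object CseSketch {
  sealed trait Exp
  final case class Con(v: Double) extends Exp
  final case class Arg(name: String) extends Exp
  final case class Sym(id: Int) extends Exp        // refers to a recorded definition

  sealed trait Def
  final case class DoubleTimes(a: Exp, b: Exp) extends Def

  final class Builder {
    // Definitions in creation order; identical Defs map to the same Sym.
    private val defs = mutable.LinkedHashMap.empty[Def, Sym]

    // Reuse the symbol of an identical definition, or record a new one.
    def toAtom(d: Def): Sym = defs.get(d) match {
      case Some(sym) => sym
      case None =>
        val sym = Sym(defs.size)
        defs(d) = sym
        sym
    }

    def times(r: Exp, oth: Exp): Exp = (r, oth) match {
      case (Con(1.0), _) => oth
      case (_, Con(1.0)) => r
      case _             => toAtom(DoubleTimes(r, oth))
    }

    def pow(e: Exp, p: Int): Exp =
      if (p == 0) Con(1.0)
      else if (p % 2 == 1) times(e, pow(e, p - 1))
      else { val x = pow(e, p / 2); times(x, x) }

    def definitions: List[(Sym, Def)] = defs.toList.map { case (d, s) => (s, d) }
  }
}
```

Staging pow(Arg("a1"), 5) records three definitions, matching the 3 instructions of the fun1 block. Staging pow(a1, 5) * pow(a2, 5) records seven, as in the fun2 block. Staging pow(a1, 5) * pow(a1, 5) records only four, because the second pow(a1, 5) reuses the existing symbols.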

With these definitions, executing pow produces an intermediate representation of the staged expression. From there, the execute and function methods (which are rewritten to their implementations in staging.scala) trigger the stage method, which in turn creates the code and the output you have seen.
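The last step, turning recorded definitions into source code, can be sketched as a pretty-printer. This is not the plugin's stage method. emitBlock and the IR types are hypothetical, and the sketch only prints a list of definitions in the same shape as the listings produced by the second run. Feeding it the three definitions from the execute(pow(3, 5)) run reproduces that first block, without its trailing comment.

```scala
object CodegenSketch {
  sealed trait Exp
  final case class Con(v: Double) extends Exp
  final case class Arg(name: String) extends Exp
  final case class Sym(id: Int) extends Exp

  final case class DoubleTimes(a: Exp, b: Exp)

  // Symbols are printed as x0, x1, ... like in the plugin's output.
  def show(e: Exp): String = e match {
    case Con(v)  => v.toString
    case Arg(n)  => n
    case Sym(id) => s"x$id"
  }

  // Emit one `val` per definition, in order, followed by the result expression.
  def emitBlock(defs: List[(Sym, DoubleTimes)], result: Exp): String = {
    val body = defs.map { case (sym, DoubleTimes(a, b)) =>
      s"  val ${show(sym)}: Double = ${show(a)} * ${show(b)}"
    }
    val lines = List("{") ++ body ++ List(s"  ${show(result)}: Double", "}")
    lines.mkString("\n")
  }
}
```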

We won't go into much detail about the staging part, since it is not our contribution and has been described in depth in the work on the lightweight modular staging framework. Instead, the next chapter will focus on how the staging plugin transforms the annotations into the code that stages the pow function.

Challenge

Can you think of an extra case for infix_*? (scroll down for the answer)




















Hint: we only treated the case where one of the two operands is a constant.




















What if the two operands are constants? We can add this case:

def infix_*(r: Exp[Double], oth: Exp[Double]): Exp[Double] =
  (r, oth) match {
    case (Con(x), Con(y)) => Con(x * y)
    case (Con(1), _) => oth
    case (_, Con(1)) => r
    case _ => DoubleTimes(r, oth)
  }

With this case added, running the example again produces:

$ st-scalac pow.scala
$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
{
  243.0: Double
} // end of code block of 0 instructions

*********************************
...

This means that the first call, where both the base and the exponent were known, was evaluated down to a constant. This is how staging can further optimize staged programs.
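The effect of the extra case can also be checked without the plugin. In this sketch the IR types are hypothetical stand-ins for the plugin runtime, and infix_* includes the new constant-folding case. With a constant base, pow folds all the way down to Con(243.0). With an argument as the base, it still builds multiplication nodes.

```scala
object FoldingSketch {
  // Hypothetical tree IR standing in for the plugin runtime's types.
  sealed trait Exp
  final case class Con(v: Double) extends Exp
  final case class Arg(name: String) extends Exp
  final case class Times(a: Exp, b: Exp) extends Exp

  def infix_*(r: Exp, oth: Exp): Exp = (r, oth) match {
    case (Con(x), Con(y)) => Con(x * y)   // both known: multiply at staging time
    case (Con(1.0), _)    => oth          // 1 * x == x
    case (_, Con(1.0))    => r            // x * 1 == x
    case _                => Times(r, oth)
  }

  def pow(e: Exp, p: Int): Exp =
    if (p == 0) Con(1.0)
    else if (p % 2 == 1) infix_*(e, pow(e, p - 1))
    else { val x = pow(e, p / 2); infix_*(x, x) }
}
```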

NOTE: The staging plugin in the DRT virtual machine contains a bug, so this run will instead output:

$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
scala.NotImplementedError: an implementation is missing
	at scala.Predef$.$qmark$qmark$qmark(Predef.scala:225)
	at stagium.staging$$anonfun$3.apply(staging.scala:77)
...

To fix this, update and rebuild the plugin:

$ cd ~/Workspace/staging-plugin/
$ git pull origin master 
...
$ sbt package
[success] ...

See the next chapter for the transformations that allow staging the pow function.
