Introduction
Multi-stage programming differs from partial evaluation in the context in which the program is transformed: For multi-stage programming (called staging from here on) the context is given by executing the program itself, including its side-effects, and extracting a lifted expression that can further be transformed into code. On the other hand, partial evaluation is performed by an external tool, which is only able to evaluate a set of predefined side-effect-free operations.
Therefore, staging brings new optimization opportunities compared to partial evaluation: it can burn through thick layers of abstraction and output specialized operations for the task at hand. Still, executing side-effecting code as part of staging can have an undesirable consequence: the code generated for a staged expression may no longer be correct later, if the side-effecting operations would return different results at that point.
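The problem can be illustrated without the plugin. The following is a minimal sketch in plain Scala. The tiny expression language (Expr, Lit, Input, Mul) and the mutable setting scale are hypothetical names introduced only for this example. Because scale is read while the expression is being built, its staging-time value is baked into the staged result, and later updates to scale are silently ignored by it.

```scala
object SideEffectDemo {
  // Hypothetical miniature expression language, for illustration only.
  sealed trait Expr
  case class Lit(v: Double) extends Expr
  case object Input extends Expr
  case class Mul(a: Expr, b: Expr) extends Expr

  // A mutable setting; reading it is the side-effecting operation.
  var scale: Double = 2.0

  // "Staging": builds an expression, reading `scale` right now.
  def stageScaled(): Expr = Mul(Lit(scale), Input)

  // Evaluates a staged expression for a given input value.
  def eval(e: Expr, x: Double): Double = e match {
    case Lit(v)    => v
    case Input     => x
    case Mul(a, b) => eval(a, x) * eval(b, x)
  }

  // The unstaged version reads `scale` every time it runs.
  def direct(x: Double): Double = scale * x
}
```

After staging with scale = 2.0 and then setting scale = 10.0, eval(staged, 3.0) still returns 6.0, whereas direct(3.0) returns 30.0.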
Let us take an example.
Note: Although the Unified Data Representation paper describes the annotation as @lifted, the implementation actually uses @staged instead.
$ cd ~/Workspace/staging-plugin/examples/
$ cat pow.scala
package stagium
package examples
package pow
import scala.reflect.runtime.universe._
// A test for the power function
object Test {
  def main(args: Array[String]): Unit = {
    // this is a method with staged arguments
    def pow(e: Double @staged, p: Int): Double @staged =
      if (p == 0) 1.0
      else if (p % 2 == 1) e * pow(e, p - 1)
      else { // p % 2 == 0
        val x = pow(e, p / 2)
        x * x
      }
    // and this is what we stage
    println("execute: " + execute(pow(3, 5)) + "\n")
    val fun1 = function1[Double, Double](e => pow(e, 5))
    println("fun1(3): " + fun1(3) + "\n")
    val fun2 = function2[Double, Double, Double]((e1, e2) => pow(e1, 5) * pow(e2, 5))
    println("fun2(3, 1): " + fun2(3, 1) + "\n")
  }
}
...
In this example, it would be very attractive to partially evaluate pow for the exponent p. This would allow faster execution when the exponent is fixed ahead of time. This is where the @staged annotation comes into play: it marks next-stage values, namely the values that are not treated as constants during partial evaluation (note that in staging, everything except the @staged values is considered a staging-time constant).
These @staged-annotated values, with the help of the staging plugin runtime and transformation, gather the expressions they are involved in into an intermediate representation. From this intermediate representation, new, optimized code can be output, which can execute significantly faster. In Scala, this has typically been done using the scala-virtualized compiler together with the lightweight modular staging framework. This plugin explores a different avenue of compiler-level staging, based on the data representation transformation.
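For comparison, part of the benefit of specializing pow for a known exponent can be obtained in plain Scala, without any plugin, by recursing on p once and returning a closure. This is only a sketch, and specializePow is a hypothetical name. The tests on p run when the closure is built rather than each time it is applied, but the nested closure calls remain, which is the kind of overhead that generating fresh code avoids.

```scala
object ClosureSpecialization {
  // The recursion on the exponent happens here, once, when the closure is built.
  def specializePow(p: Int): Double => Double = {
    require(p >= 0, "exponent must be non-negative")
    if (p == 0) (_: Double) => 1.0
    else if (p % 2 == 1) {
      val rest = specializePow(p - 1)
      (e: Double) => e * rest(e)
    } else {
      val half = specializePow(p / 2)
      (e: Double) => { val x = half(e); x * x }
    }
  }
}
```

For example, specializePow(5)(3.0) evaluates to 243.0.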
Compiling and running the example, first in passive mode and then with staging active, yields:
$ st-scalac pow.scala -P:stagium:passive
$ st-scala stagium.examples.pow.Test
execute: 243.0
fun1(3): 243.0
fun2(3, 1): 243.0
$ st-scalac pow.scala
$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
{
val x0: Double = 3.0 * 3.0
val x1: Double = x0 * x0
val x2: Double = 3.0 * x1
x2: Double
} // end of code block of 3 instructions
*********************************
execute: null
Need to compile and run:
*********************************
{
val x6: Double => Double =
(a1: Double) => {
val x3: Double = a1 * a1
val x4: Double = x3 * x3
val x5: Double = a1 * x4
x5: Double
} // end of code block of 3 instructions
x6: Double => Double
} // end of function x6: Double => Double
*********************************
<function1 called>
fun1(3): 0.0
Need to compile and run:
*********************************
{
val x14: (Double, Double) => Double =
(a1: Double, a2: Double) => {
val x10: Double = a2 * a2
val x7: Double = a1 * a1
val x11: Double = x10 * x10
val x8: Double = x7 * x7
val x12: Double = a2 * x11
val x9: Double = a1 * x8
val x13: Double = x9 * x12
x13: Double
} // end of code block of 7 instructions
x14: (Double, Double) => Double
} // end of function x14: (Double, Double) => Double
*********************************
<function2 called>
fun2(3, 1): 0.0
In the first run, the code was executed immediately, as the staging plugin was set to passive mode. In the second run, the intermediate representation was extracted and new code was output for the partially evaluated pow function.
NOTE: At the time this wiki page was written, the mechanism to compile and execute the generated code was not yet available, so execute returned null and the staged functions returned 0.0 instead of the actual values. We plan to add this functionality.
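One way to picture the two modes is to write pow once against an abstract number interface, then interpret it either directly, as in passive mode, or symbolically, as in the second run. The sketch below uses a tagless-final style with the hypothetical names Num, Direct and Symbolic. It is an analogy for illustration, not how the staging plugin implements its modes.

```scala
object Modes {
  // An abstract interface over how "staged" doubles are represented.
  trait Num[R] {
    def lit(d: Double): R
    def times(a: R, b: R): R
  }

  // Direct interpretation: computes the value immediately (passive-like).
  val Direct: Num[Double] = new Num[Double] {
    def lit(d: Double): Double = d
    def times(a: Double, b: Double): Double = a * b
  }

  // Symbolic interpretation: builds a textual expression (active-like).
  val Symbolic: Num[String] = new Num[String] {
    def lit(d: Double): String = d.toString
    def times(a: String, b: String): String = s"($a * $b)"
  }

  // pow written once: p is known now, e has representation R.
  def pow[R](e: R, p: Int, n: Num[R]): R =
    if (p == 0) n.lit(1.0)
    else if (p % 2 == 1) n.times(e, pow(e, p - 1, n))
    else { val x = pow(e, p / 2, n); n.times(x, x) }
}
```

Here pow(3.0, 5, Direct) returns 243.0, while pow("a", 2, Symbolic) returns the string ((a * 1.0) * (a * 1.0)). This naive symbolic interpretation duplicates the shared subterm and keeps the multiplication by 1.0. The plugin output above avoids both, thanks to symbols and to the simplifications in infix_* described below.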
How exactly did the staged code get generated? We use the same mechanism as the lightweight modular staging framework, namely redirecting method calls on staged values to infix methods, which in turn create the intermediate representation.
To see how this works, it helps to look at the entire pow.scala example:
$ cat pow.scala
package stagium
package examples
package pow
import scala.reflect.runtime.universe._
// A test for the power function
object Test {
  def main(args: Array[String]): Unit = {
    // this is a method with staged arguments
    def pow(e: Double @staged, p: Int): Double @staged =
      if (p == 0) 1.0
      else if (p % 2 == 1) e * pow(e, p - 1)
      else { // p % 2 == 0
        val x = pow(e, p / 2)
        x * x
      }
    // and this is what we stage
    println("execute: " + execute(pow(3, 5)) + "\n")
    val fun1 = function1[Double, Double](e => pow(e, 5))
    println("fun1(3): " + fun1(3) + "\n")
    val fun2 = function2[Double, Double, Double]((e1, e2) => pow(e1, 5) * pow(e2, 5))
    println("fun2(3, 1): " + fun2(3, 1) + "\n")
  }
}
// This is the support object for staging
object __staged {
  // IR node for a staged multiplication, printed as "t1 * t2"
  case class DoubleTimes(t1: Exp[Double], t2: Exp[Double]) extends Def[Double] {
    override def toString = t1 + " * " + t2
  }
  // multiplication on staged doubles is redirected here
  def infix_*(r: Exp[Double], oth: Exp[Double]): Exp[Double] =
    (r, oth) match {
      case (Con(1), _) => oth // 1 * oth == oth
      case (_, Con(1)) => r   // r * 1 == r
      case _ => DoubleTimes(r, oth)
    }
}
As you can see, the example also defines infix_*, the method to which multiplication (*) gets redirected when the receiver is staged. The infix_* body contains a simple optimization: multiplying a number by 1 yields the same number. You can see this in the first two cases of the match expression: Con(1) represents a constant of value 1, so when either operand is 1, the other operand is returned unchanged.
The final translation of pow, when the staging plugin is active, is:
def pow(e: Exp[Double], p: Int): Exp[Double] =
  if (p == 0) Con[Double](1.0)
  else if (p % 2 == 1) infix_*(e, pow(e, p - 1))
  else { // p % 2 == 0
    val x = pow(e, p / 2)
    infix_*(x, x)
  }
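Outside the plugin, a similar redirection can be approximated in ordinary Scala with an implicit class, so that a * b on staged values ends up calling infix_*. The sketch below assumes simplified, hypothetical versions of Exp, Con and Arg, and a Times node in place of DoubleTimes. In the plugin itself, as the translation above shows, the calls are rewritten to infix_* directly by the compiler transformation.

```scala
object InfixRedirect {
  // Simplified, hypothetical IR: staged doubles are expression trees.
  sealed trait Exp
  case class Con(value: Double) extends Exp
  case class Arg(index: Int) extends Exp
  case class Times(a: Exp, b: Exp) extends Exp

  // Target of the redirection, analogous to infix_* in __staged.
  def infix_*(r: Exp, oth: Exp): Exp = Times(r, oth)

  // Makes `r * oth` on any Exp forward to infix_*.
  implicit class StagedOps(r: Exp) {
    def *(oth: Exp): Exp = infix_*(r, oth)
  }
}
```

With import InfixRedirect._ in scope, writing a * a for a: Exp = Arg(1) builds Times(Arg(1), Arg(1)).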
The type Exp marks next-stage values, such as constants (Con), arguments (Arg) or symbols (Sym). These are all defined in the value class plugin runtime. Symbols correspond to definitions, which enables common subexpression elimination and other standard optimizations.
With these definitions, executing pow produces an intermediate representation of the staged expression. From there, the execute and function methods (function1, function2), which are rewritten to their implementations in staging.scala, trigger the stage method, which in turn generates the code and prints the output you have seen.
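The whole pipeline, from definitions bound to symbols to the printed block, can be sketched in a few dozen lines. The version below is a hypothetical, simplified stand-in for the runtime and the stage method, not their actual code. Graph hands out one symbol per distinct definition, which is what makes common subexpression elimination fall out naturally. The times method mirrors the unit rules of infix_*, and emit prints the recorded bindings in creation order.

```scala
import scala.collection.mutable

object MiniStage {
  // Hypothetical, simplified IR: atoms (Exp) and definitions (Def).
  sealed trait Exp
  case class Con(value: Double) extends Exp { override def toString = value.toString }
  case class Arg(index: Int) extends Exp { override def toString = s"a$index" }
  case class Sym(id: Int) extends Exp { override def toString = s"x$id" }

  sealed trait Def
  case class DoubleTimes(t1: Exp, t2: Exp) extends Def {
    override def toString = s"$t1 * $t2"
  }

  // Records each distinct definition once and names it with a symbol.
  class Graph {
    private val table = mutable.LinkedHashMap.empty[Def, Sym]
    def toAtom(d: Def): Sym = table.get(d) match {
      case Some(s) => s // seen before: reuse the symbol (CSE)
      case None =>
        val s = Sym(table.size)
        table(d) = s
        s
    }
    def bindings: List[(Def, Sym)] = table.toList
  }

  // Staged multiplication with the unit rules of infix_*.
  def times(g: Graph, r: Exp, oth: Exp): Exp = (r, oth) match {
    case (Con(c), _) if c == 1.0 => oth
    case (_, Con(c)) if c == 1.0 => r
    case _                       => g.toAtom(DoubleTimes(r, oth))
  }

  // pow as translated by the plugin, threading the graph explicitly.
  def pow(g: Graph, e: Exp, p: Int): Exp =
    if (p == 0) Con(1.0)
    else if (p % 2 == 1) times(g, e, pow(g, e, p - 1))
    else { val x = pow(g, e, p / 2); times(g, x, x) }

  // Prints the bindings as a block of vals followed by the result.
  def emit(g: Graph, result: Exp): String = {
    val body = g.bindings.map { case (d, s) => s"  val $s: Double = $d" }
    (List("{") ++ body ++ List(s"  $result: Double", "}")).mkString("\n")
  }
}
```

Staging pow(g, Arg(1), 5) on a fresh Graph g and emitting it produces:

{
  val x0: Double = a1 * a1
  val x1: Double = x0 * x0
  val x2: Double = a1 * x1
  x2: Double
}

This has the same shape as the fun1 block printed by the plugin. Calling times(g, Arg(1), Arg(1)) again afterwards returns the existing x0 instead of creating a new definition.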
We won't go into much detail about the staging part, since it is not our contribution and it is well described in the work on the lightweight modular staging framework. Instead, the next chapter will focus on how the staging plugin transforms the annotations into the code that stages the pow function.
Can you think of an extra case for infix_*? (scroll down for the answer)
Hint: we only treated the case where one of the two operands is a constant.
What if the two operands are constants? We can add this case:
def infix_*(r: Exp[Double], oth: Exp[Double]): Exp[Double] =
  (r, oth) match {
    case (Con(x), Con(y)) => Con(x * y) // both constant: fold at staging time
    case (Con(1), _) => oth
    case (_, Con(1)) => r
    case _ => DoubleTimes(r, oth)
  }
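Here is a self-contained sketch of the extended infix_* that runs without the plugin. It uses simplified, hypothetical Exp, Con, Arg and DoubleTimes definitions, and guards (c == 1.0) in place of the Con(1) patterns. Since constant folding is tried first, multiplying two constants always produces a new constant, including when one of them is 1.

```scala
object ConstantFolding {
  // Simplified, hypothetical IR as plain trees.
  sealed trait Exp
  case class Con(value: Double) extends Exp
  case class Arg(index: Int) extends Exp
  case class DoubleTimes(t1: Exp, t2: Exp) extends Exp

  // infix_* with constant folding tried before the unit rules.
  def infix_*(r: Exp, oth: Exp): Exp = (r, oth) match {
    case (Con(x), Con(y))        => Con(x * y)
    case (Con(c), _) if c == 1.0 => oth
    case (_, Con(c)) if c == 1.0 => r
    case _                       => DoubleTimes(r, oth)
  }

  // pow as translated by the plugin.
  def pow(e: Exp, p: Int): Exp =
    if (p == 0) Con(1.0)
    else if (p % 2 == 1) infix_*(e, pow(e, p - 1))
    else { val x = pow(e, p / 2); infix_*(x, x) }
}
```

Here pow(Con(3.0), 5) folds all the way down to Con(243.0), which corresponds to the 243.0 result printed below, while pow(Arg(1), 5) still builds multiplication nodes.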
With this case added, running the example again results in:
$ st-scalac pow.scala
$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
{
243.0: Double
} // end of code block of 0 instructions
*********************************
...
This means that the first call, where both the base and the exponent were known, was evaluated down to a constant. This is how staging can further optimize staged programs.
NOTE: The staging plugin in the DRT virtual machine contains a bug and outputs the following instead:
$ st-scala stagium.examples.pow.Test
Need to compile and run:
*********************************
scala.NotImplementedError: an implementation is missing
at scala.Predef$.$qmark$qmark$qmark(Predef.scala:225)
at stagium.staging$$anonfun$3.apply(staging.scala:77)
...
To fix this, do:
$ cd ~/Workspace/staging-plugin/
$ git pull origin master
...
$ sbt package
[success] ...
See the next chapter for the transformations that allow staging the pow function.