diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index cbcccb11f14ae..6b9aa5071e1d0 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.Encoder
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
-        |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+        |val simpleSum = new Aggregator[Int, Int, Int] {
         |  def zero: Int = 0                     // The initial value.
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
         |  val numeric = implicitly[Numeric[N]]
         |  override def zero: N = numeric.zero
         |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6bee880640ced..f148a6df47607 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
     // We need to use local-cluster to test this case.
     val output = runInterpreter("local-cluster[1,1,1024]",
       """
-        |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-        |import sqlContext.implicits._
         |case class TestCaseClass(value: Int)
         |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+        |
+        |// Test Dataset Serialization in the REPL
+        |Seq(TestCaseClass(1)).toDS().collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("Datasets and encoders") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |val simpleSum = new Aggregator[Int, Int, Int] {
+        |  def zero: Int = 0                     // The initial value.
+        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
+        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+        |  def finish(b: Int) = b                // Return the final result.
+        |}.toColumn
+        |
+        |val ds = Seq(1, 2, 3, 4).toDS()
+        |ds.select(simpleSum).collect
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
+  test("Datasets agg type-inference") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+        |class SumOf[I, N : Numeric](f: I => N) extends
+        |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
+        |  val numeric = implicitly[Numeric[N]]
+        |  override def zero: N = numeric.zero
+        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+        |  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+        |  override def finish(reduction: N): N = reduction
+        |}
+        |
+        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("Exception", output)
     assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
   }
+
+  test("line wrapper only initialized once when used as encoder outer scope") {
+    val output = runInterpreter("local",
+      """
+        |val fileName = "repl-test-" + System.currentTimeMillis
+        |val tmpDir = System.getProperty("java.io.tmpdir")
+        |val file = new java.io.File(tmpDir, fileName)
+        |def createFile(): Unit = file.createNewFile()
+        |
+        |createFile();case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+        |
+        |file.delete()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ccc65b4e5256e..ebb3a931da09f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -571,7 +571,7 @@ class Analyzer(
             if n.outerPointer.isEmpty &&
                n.cls.isMemberClass &&
                !Modifier.isStatic(n.cls.getModifiers) =>
-          val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+          val outer = OuterScopes.getOuterScope(n.cls)
           if (outer == null) {
             throw new AnalysisException(
               s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index a753b187bcd32..c047e96463544 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -21,6 +21,8 @@ import java.util.concurrent.ConcurrentMap
 
 import com.google.common.collect.MapMaker
 
+import org.apache.spark.util.Utils
+
 object OuterScopes {
 
   @transient lazy val outerScopes: ConcurrentMap[String, AnyRef] =
@@ -28,7 +30,7 @@
 
   /**
    * Adds a new outer scope to this context that can be used when instantiating an `inner class`
-   * during deserialialization. Inner classes are created when a case class is defined in the
+   * during deserialization. Inner classes are created when a case class is defined in the
    * Spark REPL and registering the outer scope that this class was defined in allows us to create
    * new instances on the spark executors. In normal use, users should not need to call this
    * function.
@@ -39,4 +41,47 @@ object OuterScopes {
   def addOuterScope(outer: AnyRef): Unit = {
     outerScopes.putIfAbsent(outer.getClass.getName, outer)
   }
+
+  def getOuterScope(innerCls: Class[_]): AnyRef = {
+    assert(innerCls.isMemberClass)
+    val outerClassName = innerCls.getDeclaringClass.getName
+    val outer = outerScopes.get(outerClassName)
+    if (outer == null) {
+      outerClassName match {
+        // If the outer class is generated by the REPL, users don't need to register it, as it
+        // has only one instance and there is a way to retrieve it: load the `$read` object,
+        // call its `INSTANCE()` method to get the single instance of class `$read`, then call
+        // the `$iw()` method multiple times to reach the single instance of the innermost
+        // `$iw` class.
+        case REPLClass(baseClassName) =>
+          val objClass = Utils.classForName(baseClassName + "$")
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
+          val baseClass = Utils.classForName(baseClassName)
+
+          var getter = iwGetter(baseClass)
+          var obj = baseInstance
+          while (getter != null) {
+            obj = getter.invoke(obj)
+            getter = iwGetter(getter.getReturnType)
+          }
+
+          outerScopes.putIfAbsent(outerClassName, obj)
+          obj
+        case _ => null
+      }
+    } else {
+      outer
+    }
+  }
+
+  private def iwGetter(cls: Class[_]) = {
+    try {
+      cls.getMethod("$iw")
+    } catch {
+      case _: NoSuchMethodException => null
+    }
+  }
+
+  // The format of a REPL-generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`.
+  private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
 }
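
For context (outside the patch itself), here is a small, self-contained sketch of how the `REPLClass` pattern introduced above is expected to behave. The demo object name and the sample class-name strings are illustrative assumptions about the wrapper layout the Scala 2.11 REPL generates (a per-line `$read` object holding nested `$iw` classes); only the regular expression is copied from the change.

object REPLClassDemo {
  // Same pattern as the `REPLClass` regex added to OuterScopes in this patch.
  private val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r

  def main(args: Array[String]): Unit = {
    // Hypothetical declaring-class name of a case class defined in the shell:
    // the REPL wraps each line in nested `$iw` classes under a per-line `$read` object.
    val replDeclaringClass = "$line12.$read$$iw$$iw"

    replDeclaringClass match {
      // `baseClassName` captures "$line12.$read"; getOuterScope then loads `$line12.$read$`,
      // reads its `MODULE$` field, calls `INSTANCE()`, and walks `$iw()` getters down to the
      // innermost wrapper instance that owns the case class.
      case REPLClass(baseClassName) => println(s"REPL wrapper, base class: $baseClassName")
      case _ => println("not a REPL wrapper class")
    }

    // An ordinary inner class does not match, so its outer instance still has to be
    // registered explicitly via OuterScopes.addOuterScope.
    println(REPLClass.findFirstIn("com.example.Outer$Inner")) // prints None
  }
}

Note that getOuterScope still consults the registered outerScopes map first; the reflective walk is only a fallback for class names that match the REPL wrapper pattern.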