From 2099e5d47bb2a8ac3640a5a3a1b57e89236af38d Mon Sep 17 00:00:00 2001
From: Nils Homer
Date: Wed, 21 Aug 2019 22:02:35 -0700
Subject: [PATCH] Add a test for ScatterGather showing failed deferred access with map

---
 .../scala/dagr/tasks/ScatterGatherTests.scala | 84 +++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala b/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
index d35db95b..9604bda6 100644
--- a/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
+++ b/tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
@@ -395,4 +395,88 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
       (wordLength.toInt, count.toInt)
     }.sortBy(_._1) should contain theSameElementsInOrderAs Seq((3, 9), (4, 3), (5, 3))
   }
+
+  private case class SumLines(input: Path) extends SimpleInJvmTask { // Test task: sums the integers stored one per line in `input`.
+    private var _result: Option[Int] = None // Stays None until run() has stored the sum.
+    def result: Int = this._result.get // Option.get: throws NoSuchElementException if run() has not stored a value yet.
+    def run(): Unit = _result = Some(Io.readLines(input).map(_.toInt).sum) // Every line must parse as an Int, otherwise toInt throws.
+  }
+
+  it should "map twice in a row" in {
+    val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
+
+    // Set up the input file (one entry of `lines` per line) and the path that the gathered sum is written to.
+    val input      = tmp()
+    val resultPath = tmp()
+    Io.writeLines(input, lines)
+
+    val pipeline = new Pipeline() {
+      name = "Pipeline"
+
+      override def build(): Unit = {
+        // The initial scatter: SplitByLine partitions the input across lines.
+        val scatter: Scatter[Path] = Scatter(SplitByLine(input=input))
+        scatter
+          .map { path: Path => CountWords(input=path, output=tmp()) }
+          .map { countWords: CountWords => SumLines(input=countWords.output) } // Only reads the constructor argument `output`, never a value computed by run().
+          .gather { sumLines: Seq[SumLines] => WriteNumber(sumLines.map(_.result).sum, resultPath) }
+        root ==> scatter
+      }
+    }
+
+    val taskManager = buildTaskManager
+    taskManager.addTask(pipeline)
+    val taskMap = taskManager.runToCompletion(true)
+
+    taskMap.foreach { case (_, info) =>
+      info.status shouldBe TaskStatus.SUCCEEDED
+    }
+
+    val outputLines = Io.readLines(resultPath).toList
+    outputLines should have size 1
+    outputLines.head.toInt shouldBe 15 // Presumably 1 + 2 + 3 + 4 + 5 words over the five input lines; depends on CountWords' output format (not shown here).
+  }
+
+  /** Test partitioner that performs no splitting: running it publishes `inputs` unchanged as the partitions. */
+  private case class IdentityPartitioner[Result](inputs: Seq[Result]) extends SimpleInJvmTask with Partitioner[Result] {
+    name = s"Partition to $inputs"
+    var partitions: Option[Seq[Result]] = None // None until run() has executed.
+    def run(): Unit = this.partitions = Some(inputs)
+  }
+
+  private case class AddOne(input: Int) extends SimpleInJvmTask { // Test task: computes `input + 1` when run.
+    name = s"AddOne to $input"
+    var _result: Option[Int] = None // None until run() has executed.
+    def result: Int = _result.getOrElse { // Fails with a descriptive message when read before run() has stored a value.
+      throw new IllegalArgumentException(s"Accessing result before $name has completed")
+    }
+    override def run(): Unit = this._result = Some(input + 1)
+  }
+
+  it should "map twice in a row requiring deferred access" in {
+    val output = tmp()
+    val pipeline = new Pipeline() {
+      name = "Pipeline"
+      override def build(): Unit = {
+        val scatter: Scatter[Int] = Scatter(IdentityPartitioner(Seq(1, 2, 3, 4, 5)))
+        scatter
+          .map { int => AddOne(int) }
+          .map { addOne => AddOne(addOne.result) } // FIXME: addOne.result should only be accessed **after** addOne has completed; unlike the previous test, this second map needs a value that only exists once the first task has run.
+          .gather { addOnes => WriteNumber(addOnes.map(_.result).sum, output) }
+        root ==> scatter
+      }
+    }
+
+    val taskManager = buildTaskManager
+    taskManager.addTask(pipeline)
+    val taskMap = taskManager.runToCompletion(true)
+
+    taskMap.foreach { case (_, info) =>
+      info.status shouldBe TaskStatus.SUCCEEDED
+    }
+
+    val lines = Io.readLines(output).toList
+    lines should have size 1
+    lines.head.toInt shouldBe 25 // Inputs 1..5 each incremented twice: 3 + 4 + 5 + 6 + 7.
+  }
 }