Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Add a test for ScatterGather showing failed deferred access with map #361

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,88 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
(wordLength.toInt, count.toInt)
}.sortBy(_._1) should contain theSameElementsInOrderAs Seq((3, 9), (4, 3), (5, 3))
}

/**
  * Test task that reads `input`, parses every line as an integer and stores the sum of those
  * integers. The sum only becomes available once `run()` has executed.
  *
  * @param input path of the file to read; each line is converted with `toInt`
  */
private case class SumLines(input: Path) extends SimpleInJvmTask {
  // Holds the computed sum; stays None until run() assigns it.
  private var _result: Option[Int] = None

  /**
    * The sum stored by `run()`.
    *
    * @throws NoSuchElementException if read before `run()` has stored a value
    */
  def result: Int = this._result.getOrElse {
    // Same exception type as the previous `Option.get`, but with a message that names the task,
    // mirroring the wording used by the other result-bearing test tasks in this file.
    throw new NoSuchElementException(s"Accessing result before $name has completed")
  }

  /** Reads the input lines, converts each to an `Int` and records their sum. */
  def run(): Unit = _result = Some(Io.readLines(input).map(_.toInt).sum)
}

it should "map twice in a row" in {
  // Line i contains i words (1 to 5); the expected final total below is 1+2+3+4+5 = 15.
  val sentences = Seq("one", "one two", "one two three", "one two three four", "one two three four five")

  // Input file holding the sentences, and the file the gathered total is written to.
  val input = tmp()
  val resultPath = tmp()
  Io.writeLines(input, sentences)

  val pipeline = new Pipeline() {
    name = "Pipeline"

    override def build(): Unit = {
      // Scatter over the pieces produced by SplitByLine.
      val scatter: Scatter[Path] = Scatter(SplitByLine(input=input))

      // First map: one CountWords task per scattered path, each writing to a fresh temp file.
      val counted = scatter.map { (path: Path) => CountWords(input=path, output=tmp()) }

      // Second map: sum the numbers found in each CountWords output file.
      val summed = counted.map { (countWords: CountWords) => SumLines(input=countWords.output) }

      // Gather: add up the per-partition sums and write the total to resultPath.
      summed.gather { (sumLines: Seq[SumLines]) => WriteNumber(sumLines.map(_.result).sum, resultPath) }

      root ==> scatter
    }
  }

  val manager = buildTaskManager
  manager.addTask(pipeline)
  val infoByTask = manager.runToCompletion(true)

  // Every task tracked by the manager must have succeeded.
  infoByTask.foreach { case (_, taskInfo) =>
    taskInfo.status shouldBe TaskStatus.SUCCEEDED
  }

  val written = Io.readLines(resultPath).toList
  written should have size 1
  written.head.toInt shouldBe 15
}

/**
  * Test partitioner that does no splitting of its own: when run, it publishes the `inputs` it was
  * constructed with, unchanged, as its partitions.
  *
  * @param inputs the values to expose as partitions
  * @tparam Result the element type of the partitions
  */
private case class IdentityPartitioner[Result](inputs: Seq[Result]) extends SimpleInJvmTask with Partitioner[Result] {
  name = s"Partition to $inputs"

  // None until run() has executed; afterwards Some(inputs).
  var partitions: Option[Seq[Result]] = None

  /** Publishes the constructor inputs as the partitions. */
  def run(): Unit = this.partitions = Some(inputs)
}

/**
  * Test task that computes `input + 1` when run and exposes the value through `result`.
  *
  * @param input the integer to increment
  */
private case class AddOne(input: Int) extends SimpleInJvmTask {
  name = s"AddOne to $input"

  // Empty until run() stores the incremented value.
  var _result: Option[Int] = None

  /**
    * The incremented value.
    *
    * @throws IllegalArgumentException if read before `run()` has stored a value
    */
  def result: Int = _result match {
    case Some(value) => value
    case None        => throw new IllegalArgumentException(s"Accessing result before $name has completed")
  }

  /** Stores `input + 1` as the result. */
  override def run(): Unit = {
    this._result = Some(input + 1)
  }
}

it should "map twice in a row requiring deferred access" in {
  val output = tmp()

  val pipeline = new Pipeline() {
    name = "Pipeline"

    override def build(): Unit = {
      val scatter: Scatter[Int] = Scatter(IdentityPartitioner(Seq(1, 2, 3, 4, 5)))

      // First map: increment every partitioned value.
      val firstPass = scatter.map { int => AddOne(int) }

      // Second map: increment again, which needs the result of the first AddOne.
      // FIXME: addOne.result should only be accessed **after** addOne has completed
      val secondPass = firstPass.map { addOne => AddOne(addOne.result) }

      // Gather: sum the twice-incremented values and write the total to the output file.
      secondPass.gather { addOnes => WriteNumber(addOnes.map(_.result).sum, output) }

      root ==> scatter
    }
  }

  val manager = buildTaskManager
  manager.addTask(pipeline)
  val infoByTask = manager.runToCompletion(true)

  // Every task tracked by the manager must have succeeded.
  infoByTask.foreach { case (_, taskInfo) =>
    taskInfo.status shouldBe TaskStatus.SUCCEEDED
  }

  // Each of 1..5 incremented twice gives 3+4+5+6+7 = 25.
  val written = Io.readLines(output).toList
  written should have size 1
  written.head.toInt shouldBe 25
}
}