From 333c1350363eb8f7e22372d94af4b239f7e140a9 Mon Sep 17 00:00:00 2001 From: breandan Date: Fri, 13 Dec 2024 16:15:23 -0500 Subject: [PATCH] simplify DFA decoder --- .../hypergraph/kaliningraph/automata/JFSA.kt | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt index 631271e3..53f8772e 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt @@ -50,13 +50,16 @@ fun JAutomaton.toDot(processed: MutableSet = mutableSetOf() * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1}) */ -data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double, val id: Int = traj.hashCode()) { +data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, + val score: Double, val id: Int = traj.hashCode()): Comparable { val isComplete: Boolean = lastState.isAccept val tokens by lazy { traj.reversed().filterNotNull() } + val lenNormedScore = score / traj.size fun append(tok: Σᐩ?, state: BState, score: Double) = FSATrajectory(listOf(tok) + traj, state, score, id * 31 + tok.hashCode()) override fun toString() = tokens.joinToString(" ") - override fun equals(other: Any?): Boolean = other is FSATrajectory && id == other.id + override fun equals(other: Any?): Boolean = other is FSATrajectory && lenNormedScore == other.lenNormedScore + override fun compareTo(other: FSATrajectory): Int = lenNormedScore.compareTo(other.lenNormedScore) } fun BAutomaton.min(): BAutomaton = minimize(this) @@ -105,8 +108,8 @@ fun BAutomaton.decodeDFA( beamWidth: Long = 1_000_000L, // Maximum number of trajectories to keep at each step ): List<Σᐩ> { val startTime = TimeSource.Monotonic.markNow() - val fullTrajectories = PriorityBlockingQueue(10000, compareBy { it.score / it.traj.size }) // Max-heap for full trajectories - val beam = PriorityQueue(compareBy { it.score / it.traj.size }) // Beam for partial trajectories + val fullTrajectories = PriorityBlockingQueue(10000) // Max-heap for full trajectories + val beam = PriorityQueue() // Beam for partial trajectories beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) @@ -127,17 +130,16 @@ fun BAutomaton.decodeDFA( if (traj.isComplete) { fullTrajectories.add(traj) callback(traj.toString()) - if (traj.lastState.transitions.isNotEmpty()) listOf(traj to traj.score) else emptyList() - } else { listOf(traj to traj.score) } + if (traj.lastState.transitions.isNotEmpty()) listOf(traj) else emptyList() + } else { listOf(traj) } }.stream() - }.sorted(compareBy { it.second / it.first.traj.size }) - .limit(beamWidth).map { it.first }.toList() + }.sorted().limit(beamWidth).toList() beam.clear() beam.addAll(nextBeam) } - val deduped = fullTrajectories.map { it.toString() }.distinct().toList() + val deduped = fullTrajectories.distinct().map { it.toString() }.toList() println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${beam.size} in queue") return deduped