Skip to content

Commit

Permalink
simplify DFA decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 13, 2024
1 parent 3384ffa commit 333c135
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,16 @@ fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()
* previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
*/

/**
 * A path through the automaton, stored most-recent-token-first.
 *
 * @property traj tokens of the path in reverse order ([append] prepends); `null` entries, such as
 *   the initial padding, are dropped by [tokens].
 * @property lastState the state associated with the most recent step; the path [isComplete] when
 *   this state is accepting.
 * @property score the score assigned to this path; [append] takes the new value from the caller.
 * @property id rolling hash of the path, seeded with `traj.hashCode()` and extended by [append].
 *
 * Equality is path identity ([id] as a cheap pre-check, then [traj]) and [hashCode] is [id], so
 * `distinct()` and hash-based collections de-duplicate by path. Ordering is by [lenNormedScore]
 * alone, so [compareTo] may return 0 for two paths that are not equal; priority queues and
 * sorting only rely on [compareTo], so this is safe for them.
 */
data class FSATrajectory(
  val traj: List<Σᐩ?>,
  val lastState: BState,
  val score: Double,
  val id: Int = traj.hashCode()
) : Comparable<FSATrajectory> {
  /** Whether [lastState] is an accepting state. */
  val isComplete: Boolean = lastState.isAccept

  /** The non-null tokens in emission order (i.e. [traj] reversed). */
  val tokens by lazy { traj.reversed().filterNotNull() }

  /** [score] divided by the path length; NOTE(review): not finite for an empty [traj]. */
  val lenNormedScore = score / traj.size

  /** Returns a new trajectory with [tok] prepended, ending in [state] with the given [score]. */
  fun append(tok: Σᐩ?, state: BState, score: Double) =
    FSATrajectory(listOf(tok) + traj, state, score, id * 31 + tok.hashCode())

  override fun toString() = tokens.joinToString(" ")

  // Compare ids first so the O(n) list comparison only runs on a hash match.
  override fun equals(other: Any?): Boolean =
    this === other || (other is FSATrajectory && id == other.id && traj == other.traj)

  // Must agree with equals: equal trajectories always share the same id.
  override fun hashCode(): Int = id

  override fun compareTo(other: FSATrajectory): Int = lenNormedScore.compareTo(other.lenNormedScore)
}

/** Returns the result of passing this automaton to `minimize`. */
fun BAutomaton.min(): BAutomaton {
  return minimize(this)
}
Expand Down Expand Up @@ -105,8 +108,8 @@ fun BAutomaton.decodeDFA(
beamWidth: Long = 1_000_000L, // Maximum number of trajectories to keep at each step
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
val fullTrajectories = PriorityBlockingQueue<FSATrajectory>(10000, compareBy { it.score / it.traj.size }) // Max-heap for full trajectories
val beam = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size }) // Beam for partial trajectories
val fullTrajectories = PriorityBlockingQueue<FSATrajectory>(10000) // Max-heap for full trajectories
val beam = PriorityQueue<FSATrajectory>() // Beam for partial trajectories

beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0))

Expand All @@ -127,17 +130,16 @@ fun BAutomaton.decodeDFA(
if (traj.isComplete) {
fullTrajectories.add(traj)
callback(traj.toString())
if (traj.lastState.transitions.isNotEmpty()) listOf(traj to traj.score) else emptyList()
} else { listOf(traj to traj.score) }
if (traj.lastState.transitions.isNotEmpty()) listOf(traj) else emptyList()
} else { listOf(traj) }
}.stream()
}.sorted(compareBy { it.second / it.first.traj.size })
.limit(beamWidth).map { it.first }.toList()
}.sorted().limit(beamWidth).toList()

beam.clear()
beam.addAll(nextBeam)
}

val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
val deduped = fullTrajectories.distinct().map { it.toString() }.toList()

println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${beam.size} in queue")
return deduped
Expand Down

0 comments on commit 333c135

Please sign in to comment.