diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt index 844030c1..631271e3 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt @@ -6,6 +6,7 @@ import ai.hypergraph.markovian.mcmc.MarkovChain import dk.brics.automaton.Automaton.* import dk.brics.automaton.Transition import java.util.PriorityQueue +import java.util.concurrent.PriorityBlockingQueue import kotlin.random.Random import kotlin.time.* @@ -98,99 +99,42 @@ fun BAutomaton.decodeDFA( mc: MarkovChain<Σᐩ>, // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a // string-based alphabet, so we need a way to translate between the two - dec: Map, // Maps unicode characters back to strings because BAutomata uses Unicode - callback: (Σᐩ) -> Unit = {}, - topK: Int = 10_000_000, // Total number of top-K results to return - timeout: Duration = Duration.INFINITE, -): List<Σᐩ> { - val startTime = TimeSource.Monotonic.markNow() - val load = 100_000 - val fullTrajectories = PriorityQueue(load, compareBy { it.score / it.traj.size }) - val partTrajectories = - PriorityQueue(load, compareBy { it.score / it.traj.size }) - .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) } - - while ( - fullTrajectories.size < topK && - partTrajectories.size > 0 && - startTime.elapsedNow() < timeout - ) { - val partTraj = partTrajectories.poll() - val lastToks = partTraj.traj.take(mc.memory - 1).reversed() - partTraj.lastState.transitions.flatMap { next -> - (next.min..next.max).map { tok -> - val decTok = dec[tok] - val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok) - - Triple(next, decTok, nextScore) - } - } -// .sortedBy { (_, _, nextScore) -> -nextScore }.take(100) - .forEach { (next: Transition, decTok: String?, nextScore: Double) -> - val traj = partTraj.append(decTok, next.dest, nextScore) - if (!traj.isComplete) { partTrajectories.add(traj) } - else { - fullTrajectories.add(traj.also { callback(it.toString()) }) - if (traj.lastState.transitions.isNotEmpty()) partTrajectories.add(traj) - } - } - } - - val deduped = fullTrajectories.map { it.toString() }.distinct().toList() -// .map { it.toString() to mc.score(it.tokens) } -// .distinct().toList().sortedBy { it.second }.map { it.first } - -// println("Top 10 trajectories:") -// fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") } - println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${partTrajectories.size} in queue") - - return deduped -} - -fun BAutomaton.decodeDFAWithBeamSearch( - mc: MarkovChain<Σᐩ>, dec: Map, // Maps unicode characters back to strings callback: (Σᐩ) -> Unit = {}, - topK: Int = 10_000_000, // Total number of top-K results to return timeout: Duration = Duration.INFINITE, - beamWidth: Int = 100_000, // Maximum number of trajectories to keep at each step + beamWidth: Long = 1_000_000L, // Maximum number of trajectories to keep at each step ): List<Σᐩ> { val startTime = TimeSource.Monotonic.markNow() - val fullTrajectories = PriorityQueue(compareBy { it.score / it.traj.size }) // Max-heap for full trajectories + val fullTrajectories = PriorityBlockingQueue(10000, compareBy { it.score / it.traj.size }) // Max-heap for full trajectories val beam = PriorityQueue(compareBy { it.score / it.traj.size }) // Beam for partial trajectories beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) while ( - fullTrajectories.size < topK && + fullTrajectories.size < beamWidth && beam.isNotEmpty() && startTime.elapsedNow() < timeout ) { - val nextBeam = PriorityQueue(compareBy { it.score / it.traj.size }) - - while (beam.isNotEmpty() && startTime.elapsedNow() < timeout) { - val partTraj = beam.poll() + val nextBeam = beam.parallelStream().flatMap { partTraj -> val lastToks = partTraj.traj.take(mc.memory - 1).reversed() - partTraj.lastState.transitions.flatMap { next -> (next.min..next.max).map { tok -> val decTok = dec[tok] val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok) partTraj.append(decTok, next.dest, nextScore) } - }.forEach { traj -> + }.flatMap { traj -> if (traj.isComplete) { - if (traj.lastState.transitions.isNotEmpty()) nextBeam.add(traj) fullTrajectories.add(traj) callback(traj.toString()) - } else { - nextBeam.add(traj) - } - } - } + if (traj.lastState.transitions.isNotEmpty()) listOf(traj to traj.score) else emptyList() + } else { listOf(traj to traj.score) } + }.stream() + }.sorted(compareBy { it.second / it.first.traj.size }) + .limit(beamWidth).map { it.first }.toList() beam.clear() - beam.addAll(nextBeam.take(beamWidth)) + beam.addAll(nextBeam) } val deduped = fullTrajectories.map { it.toString() }.distinct().toList()