used fixed-width beam search for decoding
breandan committed Nov 27, 2024
1 parent 60e5810 commit b48c13b
Showing 7 changed files with 133 additions and 44 deletions.
10 changes: 5 additions & 5 deletions latex/popl2025/popl.tex
@@ -150,7 +150,7 @@
From the lexical string, we build an automaton representing all strings within a certain edit distance. We then construct a synthetic grammar recognizing all strings in the intersection of the programming language and the edit ball. Finally, this grammar is reduced to a normal form and decoded with the help of a statistical model to produce a list of suggested repairs.
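
To make the dataflow concrete, here is a minimal Kotlin sketch of the three stages described above; every function name is a hypothetical stand-in, not the repository's actual API:

fun repair(brokenTokens: List<String>, grammar: CFG, maxEdits: Int): List<String> {
  // Stage 1: automaton recognizing every token sequence within maxEdits of the input
  val editBall = levenshteinAutomaton(brokenTokens, maxEdits)
  // Stage 2: synthetic grammar for the intersection of the language and the edit ball
  val intersection = grammar.intersect(editBall)
  // Stage 3: reduce to a normal form, then decode with a statistical model to rank repairs
  return intersection.normalize().decode(statisticalModel)
}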

\begin{figure}[h!]
\includegraphics[width=\textwidth]{flow.pdf}\vspace{-1pt}
\includegraphics[width=\textwidth]{flow}\vspace{-1pt}
\caption{Simplified dataflow. Given a grammar and broken code fragment, we create an automaton generating the language of small edits, then construct a grammar representing the intersection of the two languages. This grammar can be converted to a finite automaton, determinized, then decoded to produce a list of repairs.}\label{fig:arch_simp}
\end{figure}

@@ -1199,7 +1199,7 @@ \subsection{Scoring and reranking enumerated trees by likelihood}\label{sec:rank

\begin{figure}[H]
\centering
\includegraphics[width=0.9\textwidth]{exampleDFA.pdf}
\includegraphics[width=0.9\textwidth]{exampleDFA}
\caption{Minimal DFA recognizing the language $L(\footnotesize{\texttt{NAME = NAME . NAME ( NUM : , NUM : )}}, 2) \cap \ell_\textsc{Python}$.}
\label{fig:exampleDFA}
\end{figure}
@@ -1659,14 +1659,14 @@ \section{Evaluation}
Next, we measure the precision at various ranking cutoffs and wall-clock timeouts. Our method attains the same precision as Seq2Parse and BIFI for 1-edit repairs at comparable latency; however, Tidyparse takes longer to attain the same precision for 2- and 3-edit repairs. BIFI and Seq2Parse both have sub-second single-shot latency, but both are neural models trained on a much larger dataset.

\begin{figure}[h!]
% \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair.tex}}
% \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair}}
\resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_1}}
\resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_2}}
\resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_3}}
\resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_4}}
% \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_5}}
%\resizebox{.3\textwidth}{!}{\input{repair1_plot.tex}}
%\resizebox{.307\textwidth}{!}{\input{repair2_plot.tex}}
%\resizebox{.3\textwidth}{!}{\input{repair1_plot}}
%\resizebox{.307\textwidth}{!}{\input{repair2_plot}}
\caption{Human repair benchmarks. Note that the y-axis range varies across the edit-distance plots. The red line indicates Seq2Parse's and the orange line indicates BIFI's Precision@1 on the same repairs.}\label{fig:human}
\end{figure}

Binary file modified latex/thesis/Thesis.pdf
Binary file not shown.
6 changes: 3 additions & 3 deletions latex/thesis/content/Ch3_Deterministic_Repair.tex
@@ -242,13 +242,13 @@ \section{The Bar-Hillel Construction}

\section{Parikh Refinements}

To identify these superfluous triples, we define an interval domain that soundly overapproximates the Parikh image, encoding the minimum and maximum number of terminals each nonterminal can generate. Since some intervals may be right-unbounded, we write $\mathbb{N}^*=\mathbb{N} \cup \{\infty\}$ to denote the upper bound, and $\Pi = \{[a, b] \in \mathbb{N} \times \mathbb{N}^* \mid a \leq b\}^{|\Sigma|}$ to denote the Parikh image of all terminals.
To identify superfluous $q, v, q'$ triples, we define an interval domain that soundly overapproximates the Parikh image, encoding the minimum and maximum number of terminals each nonterminal must and can generate, respectively. Since some intervals may be right-unbounded, we write $\mathbb{N}^*=\mathbb{N} \cup \{\infty\}$ to denote the upper bound, and $\Pi = \{[a, b] \in \mathbb{N} \times \mathbb{N}^* \mid a \leq b\}^{|\Sigma|}$ to denote the domain of Parikh intervals over all terminals.

\begin{definition}[Parikh mapping of a nonterminal]\label{def:parikh}
Let $p: \Sigma^*\rightarrow\mathbb{N}^{|\Sigma|}$ be the Parikh operator~\cite{parikh1966context}, which counts the frequency of terminals in a string. We define the Parikh map, $\pi: V \rightarrow \Pi$, as a function returning the smallest interval such that $\forall \sigma: \Sigma^*, \forall v: V$, $v \Rightarrow^* \sigma \vdash p(\sigma) \in \pi(v)$.
\end{definition}

In other words, the Parikh mapping computes the greatest lower and least upper bound of the Parikh image over all strings in the language of a nonterminal. The infimum of a nonterminal's Parikh interval tells us how many of each terminal a nonterminal \textit{must} generate, and the supremum tells us how many it \textit{can} generate. Likewise, we define a similar relation over NFA state pairs:
The Parikh mapping computes the greatest lower and least upper bound of the Parikh image over all strings in the language of a nonterminal. The infimum of a nonterminal's Parikh interval tells us how many of each terminal a nonterminal \textit{must} generate, and the supremum tells us how many it \textit{can} generate. Likewise, we define a similar relation over NFA state pairs:
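
For example, if $v \to \texttt{a} \mid \texttt{a}\,v$, then $\pi(v)$ maps \texttt{a} to $[1, \infty]$ and every other terminal to $[0, 0]$: $v$ \textit{must} generate at least one \texttt{a}, \textit{can} generate unboundedly many, and can never generate anything else.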

\begin{definition}[Parikh mapping of NFA states]
We define $\pi: Q\times Q \rightarrow \Pi$ as returning the smallest interval such that $\forall \sigma: \Sigma^*, \forall q, q': Q$, $q \overset{\sigma}{\Longrightarrow} q' \vdash p(\sigma) \in \pi(q, q')$.
@@ -260,7 +260,7 @@ \section{Parikh Refinements}
Given two Parikh intervals $\pi, \pi': \Pi$, we define the divergence between them as $\pi \parallel \pi' = \sum_{n=1}^{|\Sigma|} \min_{(i, i') \in \pi[n]\times \pi'[n]} |i - i'|$.
\end{definition}

Now, we know that if the Parikh divergence between two intervals is nonzero, those intervals must be incompatible as no two strings, one from each Parikh interval, can be transformed into the other with fewer than $\pi \parallel \pi'$ edits.
We know that if the Parikh divergence between two intervals is nonzero, those intervals must be incompatible, since no string from one interval can be transformed into a string from the other with fewer than $\pi \parallel \pi'$ edits.

\begin{definition}[Parikh compatibility]
Let $q, q'$ be NFA states and $v$ be a CFG nonterminal. We call $\langle q, v, q'\rangle: Q\times V\times Q$ \textit{compatible} iff their divergence is zero, i.e., $v \lhd qq' \iff \big(\pi(v) \parallel \pi(q, q')\big) = 0$.
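
A minimal Kotlin sketch of this interval domain, assuming a null upper bound encodes $\infty$ (illustrative types, not the repository's actual representation):

// Interval [lo, hi]; hi == null encodes an unbounded (∞) upper limit
data class Interval(val lo: Int, val hi: Int?)

// One interval per terminal in Σ
typealias ParikhImage = List<Interval>

// Divergence: summed minimum pointwise distance between two interval vectors
fun divergence(p1: ParikhImage, p2: ParikhImage): Int =
  p1.zip(p2).sumOf { (a, b) ->
    when {
      // Overlapping intervals contribute zero
      (b.hi == null || a.lo <= b.hi) && (a.hi == null || b.lo <= a.hi) -> 0
      a.hi != null && a.hi < b.lo -> b.lo - a.hi // a lies entirely below b
      else -> a.lo - b.hi!!                      // b lies entirely below a
    }
  }

// A triple ⟨q, v, q′⟩ is compatible iff the divergence of π(v) and π(q, q′) is zero
fun compatible(piV: ParikhImage, piQQp: ParikhImage): Boolean = divergence(piV, piQQp) == 0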
37 changes: 35 additions & 2 deletions latex/thesis/content/Ch4_Probabilistic_Repair.tex
@@ -3,6 +3,39 @@ \chapter{\rm\bfseries Probabilistic Program Repair}

As we have seen, the problem of program repair is highly underdetermined. To resolve this ambiguity, we will use a probabilistic model to induce a distribution over the language of valid programs. This distribution will guide the repair process by assigning a likelihood to each possible repair. Then, taking the maximum over all possible repairs, we can find the most likely repair consistent with the constraints and the observed program.

Specifically, we will define an ordering over strings by their likelihood under the probabilistic model. We then define a repair as the most likely string consistent with the observed program and the grammar. We factorize the probability of a string as the product of the probability of each token in the string, conditioned on the previous tokens. This allows us to compute the likelihood of a string in a left-to-right fashion.
Specifically, we will define an ordering over strings by their likelihood under the probabilistic model. We then define a repair as the most likely string consistent with the observed program and the grammar. We factorize the probability of a string as the product of the probability of each token in the string, conditioned on its predecessors. This allows us to compute the joint probability in a left-to-right fashion.
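
In symbols, this is the standard chain-rule factorization, $P(\sigma_{1:n}) = \prod_{t=1}^{n} P(\sigma_t \mid \sigma_{1:t-1})$, which a Markov model with memory $m$ approximates by truncating each context to the previous $m - 1$ tokens, i.e., $P(\sigma_t \mid \sigma_{t-m+1:t-1})$.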

This probabilistic model will generally admit programs that are locally probable, but globally inconsistent with the grammar. To enforce syntactic validity, we will use the probabilistic language model to ``steer'' a generative sampler through the automaton representing the repair language. This has two advantages: first, it allows us to sample from the repair language incrementally; second, it ensures that high-probability subsequences are retrieved first and all trajectories are syntactically valid.
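
As a minimal sketch of this steering loop, consider the greedy variant below, written against a hypothetical automaton interface (the production decoder in JFSA.kt later in this commit uses a priority queue over partial trajectories instead):

// Hypothetical State/Transition types; illustrative only
class Transition(val label: String, val dest: State)
class State(val isAccept: Boolean, val transitions: List<Transition>)

fun greedySteer(start: State, lm: (List<String>, String) -> Double): List<String> {
  val tokens = mutableListOf<String>()
  var q = start
  while (!q.isAccept) {
    // Only legal automaton transitions are ever proposed, so every trajectory
    // is syntactically valid by construction; the language model steers among them
    val best = q.transitions.maxByOrNull { t -> lm(tokens, t.label) } ?: break
    tokens += best.label
    q = best.dest
  }
  return tokens
}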

We will consider two kinds of probabilistic models: a constrained Markov model and an unconstrained transformer-based neural network trained on program repair. We then evaluate both models on a syntax repair benchmark consisting of pairwise program transformations. As we will show, the constrained Markov model achieves state-of-the-art precision on blind prediction of the lexical sequence.

\begin{figure}[H]
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi_all}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_s2p}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi}}
\caption{Tidyparse, Seq2Parse and BIFI repair precision across length and edits.}
\end{figure}

Given an equivalent number of samples, the constrained Markov model attains an even wider margin.

\begin{figure}[H]
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi_all}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy200}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy20k}}
\caption{Tidyparse, Seq2Parse and BIFI repair precision across length and edits, given an equivalent number of samples.}
\end{figure}

Now, we measure latency.

\begin{figure}[H]
% \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair.tex}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_1}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_2}}
\resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_3}}
% \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_5}}
%\resizebox{.3\textwidth}{!}{\input{repair1_plot.tex}}
%\resizebox{.307\textwidth}{!}{\input{repair2_plot.tex}}
\caption{Latency benchmarks. Note the varying axis ranges. The red line marks Seq2Parse's and the orange line marks BIFI's Precision@1 on the same repairs.}\label{fig:human}
\end{figure}
111 changes: 79 additions & 32 deletions src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
@@ -5,7 +5,7 @@ import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Automaton.*
import dk.brics.automaton.Transition
import java.util.concurrent.*
import java.util.PriorityQueue
import kotlin.random.Random
import kotlin.time.*

@@ -49,10 +49,13 @@ fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()
* previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
*/

data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double) {
data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double, val id: Int = traj.hashCode()) {
val isComplete: Boolean = lastState.isAccept
val tokens by lazy { traj.reversed().filterNotNull() }
// Tokens are prepended, so traj stores the trajectory in reverse; id is a rolling hash
fun append(tok: Σᐩ?, state: BState, score: Double) =
FSATrajectory(listOf(tok) + traj, state, score, id * 31 + tok.hashCode())
override fun toString() = tokens.joinToString(" ")
override fun equals(other: Any?): Boolean = other is FSATrajectory && id == other.id
// Keep hashCode consistent with the id-based equals above
override fun hashCode(): Int = id
}

fun BAutomaton.min(): BAutomaton = minimize(this)
Expand All @@ -65,14 +68,15 @@ fun PTree.toDFA(
) =
measureTimedValue {
BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI)
val period = 5
var i = 0
var j = 0
propagator(
both = { a, b -> if (a == null) b else if (b == null) a
// Only periodically minimize the automata during construction
else if (i++ % 13 == 0) a.concatenate(b).min() else a.concatenate(b) },
else if (i++ % period == 0) a.concatenate(b).min() else a.concatenate(b) },
either = { a, b -> if (a == null) b else if (b == null) a
else if (j++ % 13 == 0) a.union(b).min() else a.union(b) },
else if (j++ % period == 0) a.union(b).min() else a.union(b) },
unit = { a -> if ("ε" in a.root) null else unitRule(a.root) }
)
}.also { println("Took ${it.duration} to build FSA") }.value
@@ -98,57 +102,100 @@ fun BAutomaton.decodeDFA(
callback: (Σᐩ) -> Unit = {},
topK: Int = 10_000_000, // Total number of top-K results to return
timeout: Duration = Duration.INFINITE,
parallelize: Boolean = false
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
val load = 100_000
val fullTrajectories = PriorityBlockingQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
val partTrajectories = Array(if(parallelize) NUM_CORES else 1) {
PriorityBlockingQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
val fullTrajectories = PriorityQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
val partTrajectories =
PriorityQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
.apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
}

fun task(id: Int = 0) {
var i = 0
while (
fullTrajectories.size < topK &&
partTrajectories.any { it.size > 0 } &&
partTrajectories.size > 0 &&
startTime.elapsedNow() < timeout
) {
if (partTrajectories[id].isEmpty()) continue
// Checks for balanced distribution of work across cores
// if (i++ % 9999 == 0) println("Trajectories[$id]: ${partTrajectories.map {it.size}}")
val partTraj = partTrajectories[id].remove()
val partTraj = partTrajectories.poll()
val lastToks = partTraj.traj.take(mc.memory - 1).reversed()
partTraj.lastState.transitions.forEach { next: Transition ->
(next.min..next.max).forEach { tok ->
partTraj.lastState.transitions.flatMap { next ->
(next.min..next.max).map { tok ->
val decTok = dec[tok]
val nextToks = lastToks + decTok
val nextScore = partTraj.score + mc.scoreChunk(nextToks)
val traj = FSATrajectory(listOf(decTok) + partTraj.traj, next.dest, nextScore)
val bin = if (parallelize) Random(traj.score.hashCode()).nextInt(NUM_CORES) else 0
if (!traj.isComplete) partTrajectories[bin].add(traj)
val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)

Triple(next, decTok, nextScore)
}
}
// .sortedBy { (_, _, nextScore) -> -nextScore }.take(100)
.forEach { (next: Transition, decTok: String?, nextScore: Double) ->
val traj = partTraj.append(decTok, next.dest, nextScore)
if (!traj.isComplete) { partTrajectories.add(traj) }
else {
fullTrajectories.add(traj)
callback(traj.toString())
if (traj.lastState.transitions.isNotEmpty())
partTrajectories[bin].add(traj)
fullTrajectories.add(traj.also { callback(it.toString()) })
if (traj.lastState.transitions.isNotEmpty()) partTrajectories.add(traj)
}
}
}
}
}

if (parallelize) (0..<NUM_CORES).toList().parallelStream().forEach { task(it) } else task(0)

val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
// .map { it.toString() to mc.score(it.tokens) }
// .distinct().toList().sortedBy { it.second }.map { it.first }

// println("Top 10 trajectories:")
// fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories")
println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${partTrajectories.size} in queue")

return deduped
}

fun BAutomaton.decodeDFAWithBeamSearch(
mc: MarkovChain<Σᐩ>,
dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings
callback: (Σᐩ) -> Unit = {},
topK: Int = 10_000_000, // Total number of top-K results to return
timeout: Duration = Duration.INFINITE,
beamWidth: Int = 100_000, // Maximum number of trajectories to keep at each step
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
val fullTrajectories = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size }) // Completed trajectories, ordered best-first by length-normalized score (PriorityQueue is a min-heap)
val beam = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size }) // Beam for partial trajectories

beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0))

while (
fullTrajectories.size < topK &&
beam.isNotEmpty() &&
startTime.elapsedNow() < timeout
) {
val nextBeam = PriorityQueue<FSATrajectory>(compareBy { it.score / it.traj.size })

while (beam.isNotEmpty() && startTime.elapsedNow() < timeout) {
val partTraj = beam.poll()
val lastToks = partTraj.traj.take(mc.memory - 1).reversed()

partTraj.lastState.transitions.flatMap { next ->
(next.min..next.max).map { tok ->
val decTok = dec[tok]
val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)
partTraj.append(decTok, next.dest, nextScore)
}
}.forEach { traj ->
if (traj.isComplete) {
if (traj.lastState.transitions.isNotEmpty()) nextBeam.add(traj)
fullTrajectories.add(traj)
callback(traj.toString())
} else {
nextBeam.add(traj)
}
}
}

beam.clear()
// PriorityQueue iteration order is unspecified, so poll() to retain the beamWidth best candidates
repeat(minOf(beamWidth, nextBeam.size)) { beam.add(nextBeam.poll()) }
}

val deduped = fullTrajectories.map { it.toString() }.distinct().toList()

println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${beam.size} in queue")
return deduped
}
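
A hypothetical invocation of the new decoder; `dfa`, `markovChain`, and `decoderMap` are illustrative stand-ins for values produced earlier in the pipeline:

val repairs = dfa.decodeDFAWithBeamSearch(
  mc = markovChain,    // MarkovChain<Σᐩ> used to score token chunks
  dec = decoderMap,    // Map<Char, Σᐩ> back to lexical tokens
  beamWidth = 10_000,
  timeout = 10.seconds // requires kotlin.time.Duration.Companion.seconds
)
repairs.take(10).forEach(::println)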
