diff --git a/latex/popl2025/popl.tex b/latex/popl2025/popl.tex
index 5ea4c11d..ce5b3e6b 100644
--- a/latex/popl2025/popl.tex
+++ b/latex/popl2025/popl.tex
@@ -150,7 +150,7 @@
 From the lexical string, we build an automaton that represents all possible strings within a certain edit distance. Then, we proceed to construct a synthetic grammar, recognizing all strings in the intersection of the programming language and the edit ball. Finally, this grammar is reduced to a normal form and decoded with the help of a statistical model to produce a list of suggested repairs.
 
 \begin{figure}[h!]
-  \includegraphics[width=\textwidth]{flow.pdf}\vspace{-1pt}
+  \includegraphics[width=\textwidth]{flow}\vspace{-1pt}
   \caption{Simplified dataflow. Given a grammar and a broken code fragment, we create an automaton generating the language of small edits, then construct a grammar representing the intersection of the two languages. This grammar can be converted to a finite automaton, determinized, then decoded to produce a list of repairs.}\label{fig:arch_simp}
 \end{figure}
@@ -1199,7 +1199,7 @@ \subsection{Scoring and reranking enumerated trees by likelihood}\label{sec:rank
 
 \begin{figure}[H]
   \centering
-  \includegraphics[width=0.9\textwidth]{exampleDFA.pdf}
+  \includegraphics[width=0.9\textwidth]{exampleDFA}
   \caption{Minimal DFA recognizing the language $L(\footnotesize{\texttt{NAME = NAME . NAME ( NUM : , NUM : )}}, 2) \cap \ell_\textsc{Python}$.}
   \label{fig:exampleDFA}
 \end{figure}
@@ -1659,14 +1659,14 @@ \section{Evaluation}
 Next, we measure the precision at various ranking cutoffs and wall-clock timeouts. Our method attains the same precision as Seq2Parse and BIFI for 1-edit repairs at comparable latency; however, Tidyparse takes longer to attain the same precision for 2- and 3-edit repairs. BIFI and Seq2Parse both have subsecond single-shot latency, but are neural models trained on a much larger dataset.
 
 \begin{figure}[h!]
-%  \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair.tex}}
+%  \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair}}
   \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_1}}
   \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_2}}
   \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_3}}
   \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_4}}
 %  \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_5}}
-%\resizebox{.3\textwidth}{!}{\input{repair1_plot.tex}}
-%\resizebox{.307\textwidth}{!}{\input{repair2_plot.tex}}
+%\resizebox{.3\textwidth}{!}{\input{repair1_plot}}
+%\resizebox{.307\textwidth}{!}{\input{repair2_plot}}
   \caption{Human repair benchmarks. Note that the y-axis range varies across the edit-distance plots. The red line indicates Seq2Parse's and the orange line BIFI's Precision@1 on the same repairs.}\label{fig:human}
 \end{figure}
diff --git a/latex/thesis/Thesis.pdf b/latex/thesis/Thesis.pdf
index bb2a872d..de80b136 100644
Binary files a/latex/thesis/Thesis.pdf and b/latex/thesis/Thesis.pdf differ
diff --git a/latex/thesis/content/Ch3_Deterministic_Repair.tex b/latex/thesis/content/Ch3_Deterministic_Repair.tex
index 8101ff5f..e5ce3d39 100644
--- a/latex/thesis/content/Ch3_Deterministic_Repair.tex
+++ b/latex/thesis/content/Ch3_Deterministic_Repair.tex
@@ -242,13 +242,13 @@ \section{The Bar-Hillel Construction}
 
 \section{Parikh Refinements}
 
-To identify these superfluous triples, we define an interval domain that soundly overapproximates the Parikh image, encoding the minimum and maximum number of terminals each nonterminal can generate.
Since some intervals may be right-unbounded, we write $\mathbb{N}^*=\mathbb{N} \cup \{\infty\}$ to denote the upper bound, and $\Pi = \{[a, b] \in \mathbb{N} \times \mathbb{N}^* \mid a \leq b\}^{|\Sigma|}$ to denote the Parikh image of all terminals.
+To identify superfluous $q, v, q'$ triples, we define an interval domain that soundly overapproximates the Parikh image, encoding the minimum and maximum number of terminals each nonterminal must and can generate, respectively. Since some intervals may be right-unbounded, we write $\mathbb{N}^*=\mathbb{N} \cup \{\infty\}$ to denote the upper bound, and $\Pi = \{[a, b] \in \mathbb{N} \times \mathbb{N}^* \mid a \leq b\}^{|\Sigma|}$ to denote the domain of Parikh intervals over all terminals.
 
 \begin{definition}[Parikh mapping of a nonterminal]\label{def:parikh}
   Let $p: \Sigma^*\rightarrow\mathbb{N}^{|\Sigma|}$ be the Parikh operator~\cite{parikh1966context}, which counts the frequency of terminals in a string. We define the Parikh map, $\pi: V \rightarrow \Pi$, as a function returning the smallest interval such that $\forall \sigma: \Sigma^*, \forall v: V$, $v \Rightarrow^* \sigma \vdash p(\sigma) \in \pi(v)$.
 \end{definition}
 
-In other words, the Parikh mapping computes the greatest lower and least upper bound of the Parikh image over all strings in the language of a nonterminal. The infimum of a nonterminal's Parikh interval tells us how many of each terminal a nonterminal \textit{must} generate, and the supremum tells us how many it \textit{can} generate. Likewise, we define a similar relation over NFA state pairs:
+The Parikh mapping computes the greatest lower and least upper bounds of the Parikh image over all strings in the language of a nonterminal. The infimum of a nonterminal's Parikh interval tells us how many of each terminal the nonterminal \textit{must} generate, and the supremum tells us how many it \textit{can} generate. Likewise, we define a similar relation over NFA state pairs:
 
 \begin{definition}[Parikh mapping of NFA states]
   We define $\pi: Q\times Q \rightarrow \Pi$ as returning the smallest interval such that $\forall \sigma: \Sigma^*, \forall q, q': Q$, $q \overset{\sigma}{\Longrightarrow} q' \vdash p(\sigma) \in \pi(q, q')$.
@@ -260,7 +260,7 @@ \section{Parikh Refinements}
   Given two Parikh intervals $\pi, \pi': \Pi$, we define the divergence between them as $\pi \parallel \pi' = \sum_{n=1}^{|\Sigma|} \min_{(i, i') \in \pi[n]\times \pi'[n]} |i - i'|$.
 \end{definition}
 
-Now, we know that if the Parikh divergence between two intervals is nonzero, those intervals must be incompatible as no two strings, one from each Parikh interval, can be transformed into the other with fewer than $\pi \parallel \pi'$ edits.
+If the Parikh divergence between two intervals is nonzero, those intervals must be incompatible, since no string drawn from one interval can be transformed into a string drawn from the other with fewer than $\pi \parallel \pi'$ edits.
 
 \begin{definition}[Parikh compatibility]
   Let $q, q'$ be NFA states and $v$ be a CFG nonterminal. We call $\langle q, v, q'\rangle: Q\times V\times Q$ \textit{compatible} iff their divergence is zero, i.e., $v \lhd qq' \iff \big(\pi(v) \parallel \pi(q, q')\big) = 0$.
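
Concretely, the divergence above is just componentwise interval arithmetic. The following is a minimal Kotlin sketch of the check (hypothetical names, not the kaliningraph API; right-unbounded intervals are approximated with Int.MAX_VALUE):

    // One [min, max] interval per terminal; max = Int.MAX_VALUE stands in for ∞.
    typealias ParikhInterval = Pair<Int, Int>
    typealias ParikhImage = List<ParikhInterval>

    // Pointwise distance between two intervals: zero when they overlap,
    // otherwise the gap separating their closest endpoints.
    fun gap(a: ParikhInterval, b: ParikhInterval): Int =
      maxOf(0, maxOf(a.first, b.first) - minOf(a.second, b.second))

    // Parikh divergence: the sum of per-terminal gaps. A nonzero value lower-bounds
    // the number of edits between any string in one image and any string in the other.
    fun divergence(p: ParikhImage, q: ParikhImage): Int =
      p.zip(q).sumOf { (a, b) -> gap(a, b) }

    // A triple <q, v, q'> is compatible iff divergence(π(v), π(q, q')) == 0.
    fun compatible(v: ParikhImage, qq: ParikhImage): Boolean = divergence(v, qq) == 0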
diff --git a/latex/thesis/content/Ch4_Probabilistic_Repair.tex b/latex/thesis/content/Ch4_Probabilistic_Repair.tex
index ef594b3b..ff04a099 100644
--- a/latex/thesis/content/Ch4_Probabilistic_Repair.tex
+++ b/latex/thesis/content/Ch4_Probabilistic_Repair.tex
@@ -3,6 +3,45 @@ \chapter{\rm\bfseries Probabilistic Program Repair}
 
 As we have seen, the problem of program repair is highly underdetermined. To resolve this ambiguity, we will use a probabilistic model to induce a distribution over the language of valid programs. This distribution will guide the repair process by assigning a likelihood to each possible repair. Then, taking the maximum over all possible repairs, we can find the most likely repair consistent with the constraints and the observed program.
 
-Specifically, we will define an ordering over strings by their likelihood under the probabilistic model. We then define a repair as the most likely string consistent with the observed program and the grammar. We factorize the probability of a string as the product of the probability of each token in the string, conditioned on the previous tokens. This allows us to compute the likelihood of a string in a left-to-right fashion.
+Specifically, we will define an ordering over strings by their likelihood under the probabilistic model. We then define a repair as the most likely string consistent with the observed program and the grammar. We factorize the probability of a string as the product of the probability of each token in the string, conditioned on its predecessors. This allows us to compute the joint probability in a left-to-right fashion.
 
-This probabilistic model will generally admit programs that are locally probable, but globally inconsistent with the grammar. To enforce syntactic validity, we will use the probabilistic language model to ``steer'' a generative sampler through the automaton representing the repair language. This has two advantages: first, it allows us to sample from the repair language incrementally, and second, it ensures that subsequences with high probability are retrieved first, and all trajectories are syntactically valid.
\ No newline at end of file
+This probabilistic model will generally admit programs that are locally probable, but globally inconsistent with the grammar. To enforce syntactic validity, we will use the probabilistic language model to ``steer'' a generative sampler through the automaton representing the repair language. This has two advantages: first, it allows us to sample from the repair language incrementally, and second, it ensures that high-probability subsequences are retrieved first and that every trajectory is syntactically valid.
+
+We will consider two kinds of probabilistic models, a constrained Markov model and an unconstrained transformer-based neural network trained on program repair, and then evaluate both models on a syntax repair benchmark consisting of pairwise program transformations. As we will show, the constrained Markov model achieves state-of-the-art precision on blind prediction of the lexical sequence.
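+
+For example, writing $\sigma = \sigma_1 \cdots \sigma_m$ for the lexical sequence, an order-$n$ Markov model factorizes the joint probability as
+\begin{equation*}
+  P(\sigma) = \prod_{t=1}^{m} P\big(\sigma_t \mid \sigma_{t-n+1}, \ldots, \sigma_{t-1}\big),
+\end{equation*}
+so each token is scored against a bounded window of its predecessors, which is what lets the sampler extend partial trajectories one token at a time.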
+
+\begin{figure}[H]
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi_all}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_s2p}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi}}
+  \caption{Tidyparse, Seq2Parse and BIFI repair precision across length and edits.}
+\end{figure}
+
+Given an equivalent number of samples, the constrained Markov model attains an even wider margin.
+
+\begin{figure}[H]
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_bifi_all}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy200}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/len_dist_tidy20k}}
+  \caption{Tidyparse repair precision with 200 and 20K samples, compared against Seq2Parse and BIFI, across length and edits.}
+\end{figure}
+
+Next, we measure latency.
+
+\begin{figure}[H]
+% \resizebox{.19\textwidth}{!}{\input{bar_hillel_repair.tex}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_1}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_2}}
+  \resizebox{.24\textwidth}{!}{\input{../popl2025/bar_hillel_repair_3}}
+% \resizebox{.24\textwidth}{!}{\input{bar_hillel_repair_5}}
+%\resizebox{.3\textwidth}{!}{\input{repair1_plot.tex}}
+%\resizebox{.307\textwidth}{!}{\input{repair2_plot.tex}}
+  \caption{Latency benchmarks. Note the varying axis ranges. The red line marks Seq2Parse's and the orange line BIFI's Precision@1 on the same repairs.}\label{fig:human}
+\end{figure}
\ No newline at end of file
diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
index 08773792..844030c1 100644
--- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
+++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
@@ -5,7 +5,7 @@ import ai.hypergraph.kaliningraph.parsing.*
 import ai.hypergraph.markovian.mcmc.MarkovChain
 import dk.brics.automaton.Automaton.*
 import dk.brics.automaton.Transition
-import java.util.concurrent.*
+import java.util.PriorityQueue
 import kotlin.random.Random
 import kotlin.time.*
@@ -49,10 +49,16 @@ fun JAutomaton.toDot(processed: MutableSet = mutableSetOf()
  * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
  */
-data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double) {
+data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double, val id: Int = traj.hashCode()) {
   val isComplete: Boolean = lastState.isAccept
   val tokens by lazy { traj.reversed().filterNotNull() }
+  fun append(tok: Σᐩ?, state: BState, score: Double) =
+    FSATrajectory(listOf(tok) + traj, state, score, id * 31 + tok.hashCode())
   override fun toString() = tokens.joinToString(" ")
+  override fun equals(other: Any?): Boolean = other is FSATrajectory && id == other.id
+  // id-based equality must come with a matching hashCode, otherwise hash-based
+  // collections (e.g., distinct()) will treat equal trajectories as distinct
+  override fun hashCode(): Int = id
 }
 
 fun BAutomaton.min(): BAutomaton = minimize(this)
@@ -65,14 +68,15 @@ fun PTree.toDFA(
 ) = measureTimedValue {
   BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI)
+  val period = 5
   var i = 0
   var j = 0
   propagator(
     both = { a, b -> if (a == null) b else if (b == null) a
       // Only periodically minimize the automata during construction
-      else if (i++ % 13 == 0) a.concatenate(b).min() else a.concatenate(b) },
+      else if (i++ % period == 0) a.concatenate(b).min() else a.concatenate(b) },
     either = { a, b -> if (a == null) b else if (b == null) a
-      else if (j++ % 13 == 0) a.union(b).min() else a.union(b) },
+      else if (j++ % period == 0) a.union(b).min() else a.union(b) },
     unit = { a -> if ("ε" in a.root) null else unitRule(a.root) }
   )
 }.also { println("Took ${it.duration} to build FSA") }.value
@@ -98,48 +102,39 @@ fun BAutomaton.decodeDFA(
   callback: (Σᐩ) -> Unit = {},
   topK: Int = 10_000_000, // Total number of top-K results to return
   timeout: Duration = Duration.INFINITE,
-  parallelize: Boolean = false
 ): List<Σᐩ> {
   val startTime = TimeSource.Monotonic.markNow()
   val load = 100_000
-  val fullTrajectories = PriorityBlockingQueue(load, compareBy<FSATrajectory> { it.score / it.traj.size })
-  val partTrajectories = Array(if(parallelize) NUM_CORES else 1) {
-    PriorityBlockingQueue(load, compareBy<FSATrajectory> { it.score / it.traj.size })
+  val fullTrajectories = PriorityQueue(load, compareBy<FSATrajectory> { it.score / it.traj.size })
+  val partTrajectories =
+    PriorityQueue(load, compareBy<FSATrajectory> { it.score / it.traj.size })
       .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
-  }
 
-  fun task(id: Int = 0) {
-    var i = 0
     while (
       fullTrajectories.size < topK &&
-      partTrajectories.any { it.size > 0 } &&
+      partTrajectories.size > 0 &&
       startTime.elapsedNow() < timeout
     ) {
-      if (partTrajectories[id].isEmpty()) continue
-// Checks for balanced distribution of work across cores
-//      if (i++ % 9999 == 0) println("Trajectories[$id]: ${partTrajectories.map { it.size }}")
-      val partTraj = partTrajectories[id].remove()
+      val partTraj = partTrajectories.poll()
       val lastToks = partTraj.traj.take(mc.memory - 1).reversed()
-      partTraj.lastState.transitions.forEach { next: Transition ->
-        (next.min..next.max).forEach { tok ->
+      partTraj.lastState.transitions.flatMap { next ->
+        (next.min..next.max).map { tok ->
           val decTok = dec[tok]
-          val nextToks = lastToks + decTok
-          val nextScore = partTraj.score + mc.scoreChunk(nextToks)
-          val traj = FSATrajectory(listOf(decTok) + partTraj.traj, next.dest, nextScore)
-          val bin = if (parallelize) Random(traj.score.hashCode()).nextInt(NUM_CORES) else 0
-          if (!traj.isComplete) partTrajectories[bin].add(traj)
+          val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)
+          Triple(next, decTok, nextScore)
+        }
+      }
+//      .sortedBy { (_, _, nextScore) -> -nextScore }.take(100)
+        .forEach { (next: Transition, decTok: String?, nextScore: Double) ->
+          val traj = partTraj.append(decTok, next.dest, nextScore)
+          if (!traj.isComplete) { partTrajectories.add(traj) }
           else {
-            fullTrajectories.add(traj)
-            callback(traj.toString())
-            if (traj.lastState.transitions.isNotEmpty())
-              partTrajectories[bin].add(traj)
+            fullTrajectories.add(traj.also { callback(it.toString()) })
+            if (traj.lastState.transitions.isNotEmpty()) partTrajectories.add(traj)
           }
         }
     }
-  }
-
-  if (parallelize) (0..<NUM_CORES).toList().parallelStream().forEach { task(it) } else task(0)
 
   val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
-  println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${partTrajectories.map { it.size }} in queue")
+  println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${partTrajectories.size} in queue")
   return deduped
 }
 
+fun BAutomaton.decodeDFAWithBeamSearch(
+  mc: MarkovChain<Σᐩ>,
+  dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings
+  callback: (Σᐩ) -> Unit = {},
+  topK: Int = 10_000_000, // Total number of top-K results to return
+  timeout: Duration = Duration.INFINITE,
+  beamWidth: Int = 100_000, // Maximum number of trajectories to keep at each step
+): List<Σᐩ> {
+  val startTime = TimeSource.Monotonic.markNow()
+  val fullTrajectories = PriorityQueue(compareBy<FSATrajectory> { it.score / it.traj.size }) // Completed trajectories, best (lowest average score) first
+  val beam = PriorityQueue(compareBy<FSATrajectory> { it.score / it.traj.size }) // Beam for partial trajectories
+
+  beam.add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0))
+
+  while (
+    fullTrajectories.size < topK &&
+    beam.isNotEmpty() &&
+    startTime.elapsedNow() < timeout
+  ) {
+    val nextBeam = PriorityQueue(compareBy<FSATrajectory> { it.score / it.traj.size })
+
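+    // Drain the current beam: expand each partial trajectory by one DFA
+    // transition, scoring each extension with the Markov chain; completed
+    // trajectories are emitted and the rest compete for a slot in nextBeam.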
+    while (beam.isNotEmpty() && startTime.elapsedNow() < timeout) {
+      val partTraj = beam.poll()
+      val lastToks = partTraj.traj.take(mc.memory - 1).reversed()
+
+      partTraj.lastState.transitions.flatMap { next ->
+        (next.min..next.max).map { tok ->
+          val decTok = dec[tok]
+          val nextScore = partTraj.score + mc.scoreChunk(lastToks + decTok)
+          partTraj.append(decTok, next.dest, nextScore)
+        }
+      }.forEach { traj ->
+        if (traj.isComplete) {
+          if (traj.lastState.transitions.isNotEmpty()) nextBeam.add(traj)
+          fullTrajectories.add(traj)
+          callback(traj.toString())
+        } else {
+          nextBeam.add(traj)
+        }
+      }
+    }
+
+    beam.clear()
+    // poll() yields trajectories in priority order; Iterable.take over a
+    // PriorityQueue would traverse the backing array in heap order instead
+    repeat(minOf(beamWidth, nextBeam.size)) { beam.add(nextBeam.poll()) }
+  }
+
+  val deduped = fullTrajectories.map { it.toString() }.distinct().toList()
+  println("Took ${startTime.elapsedNow()} to decode ${deduped.size} trajectories, with ${beam.size} in queue")
+  return deduped
+}
diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt
index 85de6d57..0d2e1372 100644
--- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt
+++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt
@@ -183,6 +183,9 @@ fun computeNTCompat(cfg: CFG, levStr: List<Σᐩ>): Array<Array<Array<Boolean>>>
   return arr
 }
 
+var filterMs = 0L
+var normMs = 0L
+
 // We pass pm and lbc because cache often flushed forcing them to be reloaded
 // and we know they will usually be the same for all calls to this function.
 fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG {
@@ -220,7 +223,7 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
     .filter { it: Π3 ->
       // Checks whether the distinct subtrajectory between two horizontal states is parseable by a given NT
       fsa.compat(it.π1, it.π2, it.π3, compat)
-      // Checks whether the length bounds for the noterminal (i.e., the range of the number of terminals it can
+      // Checks whether the length bounds for the nonterminal (i.e., the range of the number of terminals it can
      // parse) are compatible with the range of path lengths across all paths connecting two states in an FSA.
      // This is a coarse approximation, but is cheaper to compute, so it filters out most invalid triples.
&& parikhMap.ntLengthBounds[it.π3].overlaps( @@ -270,6 +273,8 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF val totalProds = binaryProds.size + transits.size + unitProds.size + initFinal.size println("Constructed ∩-grammar with $totalProds productions in ${clock.elapsedNow()}") + filterMs += clock.elapsedNow().inWholeMilliseconds + clock = TimeSource.Monotonic.markNow() return Stream.concat(binaryProds.stream(), (initFinal + transits + unitProds).stream()).parallel() // A production, e.g., * -> * [G], can be removed if the synthetic nonterminal [G] does not exist, i.e., @@ -279,6 +284,10 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF .collect(Collectors.toSet()) .also { println("Eliminated ${totalProds - it.size} extra productions before normalization") } .jvmPostProcess(clock) + .also { + normMs += clock.elapsedNow().inWholeMilliseconds + println("Fraction of time spent normalizing: " + (normMs)/(normMs.toDouble() + filterMs)) + } // .expandNonterminalStubs(origCFG = this@jvmIntersectLevFSAP) // .jdvpNew() } diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt index 10a39cdf..a1d3b3ce 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt @@ -126,7 +126,7 @@ class WFSATest { val ptreeRepairs = measureTimedValue { pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet() } - measureTimedValue { pt.toDFA()!!.decodeDFA(P_BIFI_PY150, dec = pt.termDict, parallelize = true) }.also { + measureTimedValue { pt.toDFA()!!.decodeDFA(P_BIFI_PY150, dec = pt.termDict) }.also { assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs") println("Index: ${it.value.indexOf(groundTr)}") // // Print side by side comparison of repairs
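
A minimal usage sketch of the two decoders (assumptions: `pt` is a PTree for an intersection grammar, P_BIFI_PY150 is the pretrained Markov chain used in WFSATest above, and decodeDFAWithBeamSearch is the beam-search variant reconstructed in JFSA.kt; the name and defaults may differ in the repository):

    val dfa = pt.toDFA()!!  // Build and determinize the automaton from the PTree

    // Exhaustive best-first decoding: explores partial trajectories in order of
    // average score, but the frontier can grow without bound on large automata.
    val repairs = dfa.decodeDFA(P_BIFI_PY150, dec = pt.termDict)

    // Beam-search decoding: caps the frontier at beamWidth trajectories per step,
    // trading completeness for predictable memory use and latency.
    val beamed = dfa.decodeDFAWithBeamSearch(P_BIFI_PY150, dec = pt.termDict, beamWidth = 10_000)

    println(beamed.take(10).joinToString("\n"))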