Skip to content

Commit

Permalink
implement intersection parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 14, 2025
1 parent 54a5ea0 commit b9c727e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 5 deletions.
53 changes: 53 additions & 0 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,59 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val

return false
}

/**
 * Builds a parse forest for the strings that are accepted both by [cfg] and by
 * the Levenshtein automaton produced by `makeLevFSA(str, radius)`.
 *
 * The table `dp[p][q][A]` stores the forest of derivations of nonterminal `A`
 * found between automaton states `p` and `q`. Cells are seeded from the
 * automaton's indexed transitions and then combined with the binary
 * productions in `cfg.tripleIntProds`, CYK-style.
 *
 * NOTE(review): the fill order (increasing `q - p`) and the use of `dp[0]` in
 * the final step assume that state indices are topologically ordered and that
 * state 0 is the initial state — confirm against `makeLevFSA`.
 *
 * @param str the string the Levenshtein automaton is built around
 * @param cfg the grammar; only `bimap.UNITS`, `tripleIntProds`, `bindex` and
 *   `nonterminals.size` are consulted here
 * @param radius edit-distance radius forwarded to `makeLevFSA`
 * @return a [PTree] rooted at [START_SYMBOL] whose branches are the
 *   concatenated branches of the start-symbol forests found at every final
 *   state, or `null` if no final state has a start-symbol forest
 */
fun intersectPTree(str: Σᐩ, cfg: CFG, radius: Int): PTree? {
  // 1) Build the Levenshtein automaton (acyclic)
  val levFSA = makeLevFSA(str, radius)

  val nStates = levFSA.numStates
  val startIdx = cfg.bindex[START_SYMBOL]

  // 2) Create dp array of parse trees: dp[from][to][nonterminal index]
  val dp: Array<Array<Array<PTree?>>> =
    Array(nStates) { Array(nStates) { Array(cfg.nonterminals.size) { null } } }

  // 3) Initialize terminal productions A -> a
  for ((p, Aidx, q) in levFSA.allIndexedTxs(cfg)) {
    // The nonterminal's name does not depend on the terminal, so resolve it once.
    val Aname = cfg.bindex[Aidx]

    // Fails fast (as before) if the nonterminal has no entry in UNITS.
    for (t in cfg.bimap.UNITS[Aname]!!) {
      val newLeaf = PTree(root = Aname, branches = PSingleton(t))
      dp[p][q][Aidx] = dp[p][q][Aidx]?.let { it + newLeaf } ?: newLeaf
    }
  }

  // 4) Combine sub-forests with binary productions A -> B C, visiting state
  //    pairs (p, q) in order of increasing index distance
  for (dist in 1 until nStates) {
    for (p in 0 until (nStates - dist)) {
      val q = p + dist

      // The midpoint set depends only on (p, q): look it up once per pair
      // instead of once per production. No midpoints means nothing to combine.
      val midpoints = levFSA.allPairs[p to q] ?: continue
      val cell = dp[p][q]

      // For each rule A -> B C
      for ((Aidx, Bidx, Cidx) in cfg.tripleIntProds) {
        val Aname = cfg.bindex[Aidx]

        // Check all possible midpoint states r in the DAG from p to q
        for (r in midpoints) {
          val left = dp[p][r][Bidx] ?: continue
          val right = dp[r][q][Cidx] ?: continue

          // Found a parse for A: merge it into whatever the cell already holds
          val newTree = PTree(Aname, listOf(left to right))
          cell[Aidx] = cell[Aidx]?.let { it + newTree } ?: newTree
        }
      }
    }
  }

  // 5) Gather final parse trees from dp[0][f][startIdx], for all final states f
  val allParses = levFSA.finalIdxs.mapNotNull { f -> dp[0][f][startIdx] }

  // 6) Combine them under a single "super-root"
  return if (allParses.isEmpty()) null
  else PTree(START_SYMBOL, allParses.flatMap { forest -> forest.branches })
}
}

fun walk(from: Σᐩ, next: (Σᐩ, List<Σᐩ>) -> Int): List<Σᐩ> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ val CFG.unicodeMap by cache { terminals.associateBy { Random(it.hashCode()).next
// `symbols` plus the literal "ε", materialised as a list.
val CFG.ntLst by cache { (symbols + "ε").toList() }
// Maps each entry of ntLst to its position in that list.
val CFG.ntMap by cache { ntLst.mapIndexed { i, s -> s to i }.toMap() }

// (LHS, RHS[0], RHS[1]) triples of bindex indices, built by scanning the productions whose RHS has exactly two symbols.
// NOTE(review): this declaration and the next one share a name; in this diff view this looks like the
// removed line and the next one its replacement — confirm that only one of them exists in the real file.
val CFG.tripleIntProds: Set<Π3A<Int>> by cache { filter { it.RHS.size == 2 }.map { bindex[it.LHS] to bindex[it.RHS[0]] to bindex[it.RHS[1]] }.toSet() }
// Triples of bindex indices, one per entry of bimap.TRIPL, collected into a set.
val CFG.tripleIntProds: Set<Π3A<Int>> by cache { bimap.TRIPL.map { (a, b, c) -> bindex[a] to bindex[b] to bindex[c] }.toSet() }
// For each terminal, the bindex indices of the entries returned by bimap[listOf(terminal)].
val CFG.revUnitProds: Map<Σᐩ, List<Int>> by cache { terminals.associate { it to bimap[listOf(it)].map { bindex[it] } } }

// Maps each nonterminal to the set of nonterminal pairs that can generate it,
Expand Down Expand Up @@ -249,7 +249,7 @@ class Bindex<T>(
override fun toString(): String = indexedNTs.mapIndexed { i, it -> "$i: $it" }.joinToString("\n", "Bindex:\n", "\n")
}
// Maps variables to expansions and expansions to variables in a grammar
class BiMap(cfg: CFG) {
class BiMap(val cfg: CFG) {
val L2RHS by lazy { cfg.groupBy({ it.LHS }, { it.RHS }).mapValues { it.value.toSet() } }
val R2LHS by lazy { cfg.groupBy({ it.RHS }, { it.LHS }).mapValues { it.value.toSet() } }

Expand All @@ -265,11 +265,11 @@ class BiMap(cfg: CFG) {
getOrPut(l) { mutableSetOf() }.add(symbol)
}
}
val TRIPL by lazy {
val TRIPL: List<Π3A<Σᐩ>> by lazy {
R2LHS.filter { it.key.size == 2 }
.map { it.value.map { v -> v to it.key[0] to it.key[1] } }.flatten()
}
val X2WZ: Map<Σᐩ, List<Triple<Σᐩ, Σᐩ, Σᐩ>>> by lazy {
val X2WZ: Map<Σᐩ, List<Π3A<Σᐩ>>> by lazy {
TRIPL.groupBy { it.second }.mapValues { it.value }
}
val UNITS by lazy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import kotlin.time.measureTimedValue
// Indexes a set of PTrees by their roots
typealias PForest = Map<String, PTree> // ℙ₃
// Algebraic data type / polynomial functor for parse forests (ℙ₂)
class PTree(val root: String = "", val branches: List<Π2A<PTree>> = listOf()) {
class PTree constructor(val root: String = "", val branches: List<Π2A<PTree>> = listOf()) {
// val hash by lazy { root.hashCode() + if (branches.isEmpty()) 0 else branches.hashCode() }
// override fun hashCode(): Int = hash
var ntIdx = -1

operator fun plus(other: PTree) = PTree(root, branches + other.branches)

val branchRatio: Pair<Double, Double> by lazy { if (branches.isEmpty()) 0.0 to 0.0 else
(branches.size.toDouble() + branches.sumOf { (l, r) -> l.branchRatio.first + r.branchRatio.first }) to
(1 + branches.sumOf { (l, r) -> l.branchRatio.second + r.branchRatio.second })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,12 @@ class BarHillelTest {
val led = Grammars.dyck.LED(str)
println("Language edit distance: $led")

measureTime {
FSA.intersectPTree(str, Grammars.dyck, led)!!.sampleStrWithoutReplacement()
.take(100).toList().also { assertTrue { it.isNotEmpty() } }
.forEach { assertTrue(it in Grammars.dyck.language) }
}.also { println("Enumeration took: $it") }

measureTimedValue { FSA.nonemptyLevInt(str, Grammars.dyck, led) }
.also { println("${it.value} / ${it.duration}") }
.also { assertTrue(it.value) }
Expand Down

0 comments on commit b9c727e

Please sign in to comment.