Skip to content

Commit

Permalink
simplify CNF renormalization step
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 13, 2024
1 parent 5a139fd commit 3384ffa
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ fun ptreeUnion(left: List<PTree?>, right: List<PTree?>): List<PTree?> =
}

val CFG.bitwiseAlgebra: Ring<Blns> by cache {
vindex.let {
vindex.let { vi ->
Ring.of(
nil = BooleanArray(nonterminals.size) { false },
plus = { x, y -> union(x, y) },
times = { x, y -> fastJoin(it, x, y) }
times = { x, y -> fastJoin(vi, x, y) },
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,11 @@ open class UTMatrix<T> constructor(
algebra = algebra
)
else carry.windowed(2, 1).map { window ->
window[0].second.zip(window[1].third)
.map { (l, r) -> with(algebra) { l * r } }
.fold(algebra.nil) { t, acc -> with(algebra) { acc + t } }
.let { it to (window[0].second + it) to (listOf(it) + window[1].third) }
with(algebra) { dot(window[0].π2, window[1].π3) }
.let { it to (window[0].π2 + it) to (listOf(it) + window[1].π3) }
}.let { next ->
UTMatrix(
diagonals = diagonals + listOf(next.map { it.first }),
diagonals = diagonals + listOf(next.map { it.π1 }),
algebra = algebra
).seekFixpoint(next, iteration + 1, maxIterations)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ interface Group<T>: Nat<T> {
interface Ring<T>: Group<T> {
override fun T.plus(t: T): T
override fun T.times(t: T): T
/**
 * Inner product of [l1] and [l2] under this ring: multiplies the two lists pairwise
 * and sums the products.
 *
 * Pairs are formed with `zip`, so any elements past the length of the shorter list are ignored.
 * The sum is folded starting from [nil] rather than reduced, so empty input yields [nil]
 * instead of throwing.
 *
 * @param l1 left operands
 * @param l2 right operands
 * @return the sum of `l1[i] * l2[i]` over the common indices, or [nil] if there are none
 */
fun dot(l1: List<T>, l2: List<T>): T = l1.zip(l2) { l, r -> l * r }.fold(nil) { acc, t -> acc + t }

open class of<T>(
override val nil: T, override val one: T = nil,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import kotlin.streams.*
import kotlin.time.Duration.Companion.minutes
import kotlin.time.TimeSource
import java.util.concurrent.ConcurrentHashMap
import kotlin.collections.asSequence

fun CFG.parallelEnumSeqMinimalWOR(
prompt: List<String>,
Expand Down Expand Up @@ -284,35 +285,24 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
// .jdvpNew()
}

// Parallel streaming doesn't seem to be that much faster (yet)?

fun CFG.jvmPostProcess(clock: TimeSource.Monotonic.ValueTimeMark) =
jvmElimVarUnitProds(
cfg = jvmDropVestigialProductions(clock)
).also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") }
.freeze()

tailrec fun jvmElimVarUnitProds(
cfg: CFG,
toVisit: Set<Σᐩ> = cfg.nonterminals,
vars: Set<Σᐩ> = cfg.nonterminals,
toElim: Σᐩ? = toVisit.firstOrNull()
): CFG {
fun Production.isVariableUnitProd() = RHS.size == 1 && RHS[0] in vars
if (toElim == null) return cfg.filter { !it.isVariableUnitProd() }
val varsThatMapToMe =
cfg.asSequence().asStream().parallel()
.filter { it.RHS.size == 1 && it.RHS[0] == toElim }
.map { it.LHS }.collect(Collectors.toSet())
val thingsIMapTo =
cfg.asSequence().asStream().parallel()
.filter { it.LHS == toElim }.map { it.RHS }
.collect(Collectors.toSet())
return jvmElimVarUnitProds(
(varsThatMapToMe * thingsIMapTo).fold(cfg) { g, p -> g + p },
toVisit.drop(1).toSet(),
vars
)
jvmDropVestigialProductions(clock)
.also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") }

/**
 * Eliminates unit productions whose RHS is not a terminal. For Bar-Hillel intersections, we know
 * the only examples of this are the (S -> *) rules, so elimination is much simpler than the full
 * CNF normalization.
 *
 * Steps:
 *  1. collect every symbol that appears as the LHS of some production;
 *  2. collect each such symbol `X` for which a unit production `START -> X` exists;
 *  3. drop every unit production whose sole RHS symbol is one of those `X`, and relabel the
 *     remaining productions of each `X` so that their LHS becomes `START`.
 *
 * NOTE(review): non-unit productions that mention an eliminated symbol on their RHS are left
 * untouched; this presumably relies on the Bar-Hillel invariant described above — confirm.
 *
 * @param cfg the grammar to rewrite; it is only read, never mutated.
 * @return a new set of productions with the `START -> X` unit productions inlined.
 */
fun jvmElimVarUnitProds(cfg: CFG): CFG {
  val scfg = cfg.asSequence()
  // Every symbol that occurs on the left-hand side of a production.
  val vars = scfg.asStream().parallel().map { it.first }.collect(Collectors.toSet())
  // Symbols X such that START -> X is a unit production and X itself has productions.
  val toElim = scfg.asStream().parallel()
    .filter { it.RHS.size == 1 && it.LHS == "START" && it.RHS[0] in vars }
    .map { it.RHS[0] }
    .collect(Collectors.toSet())
  val newCFG = scfg.asStream().parallel()
    // Only inspect RHS[0] for genuine unit productions, so an empty RHS is kept rather than
    // triggering an out-of-bounds access.
    .filter { it.RHS.size != 1 || it.RHS[0] !in toElim }
    // Hoist the productions of each eliminated symbol up to START.
    .map { if (it.LHS in toElim) "START" to it.RHS else it }
    .collect(Collectors.toSet())
  return newCFG
}

// TODO: Incomplete / untested
Expand Down Expand Up @@ -394,9 +384,9 @@ tailrec fun jvmElimVarUnitProds(
fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark): CFG {
val start = clock.elapsedNow()
var counter = 0
val nts: Set<Σᐩ> = asSequence().asStream().parallel().map { it.first }.collect(Collectors.toSet())
val rw = asSequence().asStream().parallel()
.filter { prod ->
val scfg = asSequence()
val nts: Set<Σᐩ> = scfg.asStream().parallel().map { it.first }.collect(Collectors.toSet())
val rw = scfg.asStream().parallel().filter { prod ->
if (counter++ % 1000 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
// Only keep productions whose RHS symbols are not synthetic or are in the set of NTs
prod.RHS.all { !(it.first() == '[' && 1 < it.length) || it in nts }
Expand All @@ -405,11 +395,11 @@ fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark):
.also { println("Removed ${size - it.size} invalid productions in ${clock.elapsedNow() - start}") }
.freeze()
.jvmRemoveUselessSymbols(nts)
// .jdvpNew()

println("Removed ${size - rw.size} vestigial productions, resulting in ${rw.size} productions.")

return if (rw.size == size) rw else rw.jvmDropVestigialProductions(clock)
return if (rw.size == size) jvmElimVarUnitProds(rw).freeze()
else rw.jvmDropVestigialProductions(clock)
}

/**
Expand Down Expand Up @@ -472,7 +462,7 @@ private fun CFG.jvmGenSym(
val nextGenerating: MutableSet<Σᐩ> = from.toMutableSet()
val TDEPS =
ConcurrentHashMap<Σᐩ, MutableSet<Σᐩ>>(size).apply {
this@jvmGenSym.asSequence().asStream().parallel()
this@jvmGenSym.toHashSet().asSequence().asStream().parallel()
.forEach { (l, r) -> r.forEach { getOrPut(it) { ConcurrentHashMap.newKeySet() }.add(l) } }
}
// this@jvmGenSym.asSequence().asStream().parallel()
Expand Down

0 comments on commit 3384ffa

Please sign in to comment.