Skip to content

Commit

Permalink
profiling and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Dec 13, 2024
1 parent 0220d53 commit 5a139fd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ val CFG.ntMap by cache { ntLst.mapIndexed { i, s -> s to i }.toMap() }
// which is then flattened to a list of adjacent pairs of nonterminal indices
val CFG.vindex: Array<IntArray> by cache {
  // One IntArray per indexed nonterminal i: the right-hand sides of its productions in bimap
  // that have exactly two symbols, with each symbol replaced by its bindex entry and the
  // resulting pairs laid out back to back (left0, right0, left1, right1, ...).
  // NOTE(review): assumes bindex[i] yields the i-th nonterminal and bindex[symbol] yields its Int index — confirm
  Array(bindex.indexedNTs.size) { i ->
    // val lhs = bindex[i]
    bimap[bindex[i]].filter { it.size == 2 }
      // Disabled experiment: order the pairs by PCFG3_BIFI score before flattening
      // .map { it to -(PCFG3_BIFI[lhs to it[0] to it[1]] ?: 0).also { s -> println("$lhs -> ${it[0]} ${it[1]} ($s)" )} }
      // .sortedBy { it.second }.map { it.first }
      .flatMap { rhs -> rhs.map { bindex[it] } }
      .toIntArray()
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ai.hypergraph.kaliningraph.parsing.noEpsilonOrNonterminalStubs
import ai.hypergraph.kaliningraph.parsing.noNonterminalStubs
import ai.hypergraph.kaliningraph.parsing.parseCFG

val s2pCFGStr = """
val s2pCFGStr = """
START -> Stmts_Or_Newlines
Stmts_Or_Newlines -> Stmt_Or_Newline | Stmt_Or_Newline Stmts_Or_Newlines
Stmt_Or_Newline -> Stmt | Newline
Expand Down
42 changes: 6 additions & 36 deletions src/commonMain/kotlin/ai/hypergraph/kaliningraph/tensor/Tensor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,9 @@ operator fun DoubleMatrix.times(value: Double): DoubleMatrix =
DoubleMatrix(numRows, numCols, data.map { it * value })

// TODO: Rewrite this from scratch using T: List<UTMatrix<T>> recursive type with overlapping trees
// Diagonals of a strictly-UT matrix for DAG-based dynamic programming
class UTMatrix<T> constructor(
// Diagonals of a strictly upper triangular matrix for DAG-based dynamic programming
// All lower diagonal and diagonal entries are zero
open class UTMatrix<T> constructor(
val diagonals: List<List<T>>, // List of strictly-UT diagonals from longest to shortest
override val algebra: Ring<T>
): AbstractMatrix<T, Ring<T>, UTMatrix<T>>(algebra, diagonals.first().size + 1) {
Expand Down Expand Up @@ -382,7 +383,9 @@ class UTMatrix<T> constructor(

// Squares this matrix by expanding it to a full matrix, applying squareUpperTriangular(),
// and converting the result back with toUTMatrix()
fun squared() = toFullMatrix().squareUpperTriangular().toUTMatrix()

fun seekFixpoint(
// Performs repeated matrix-matrix multiplication until a fixpoint is reached
// Each iteration fills in one more diagonal, stopping at the last upper diagonal
open fun seekFixpoint(
// Carries a triple of:
// (1) the element itself,
// (2) row to an element's left (inclusive)
Expand Down Expand Up @@ -411,39 +414,6 @@ class UTMatrix<T> constructor(
).seekFixpoint(next, iteration + 1, maxIterations)
}

// Loop-based fixpoint search: starting from the last diagonal, repeatedly derives the next
// diagonal and appends it, until the newest diagonal has exactly one entry or maxIterations
// rounds have run. Returns a new UTMatrix built from the accumulated diagonals and this algebra.
fun seekFixpointFast(maxIterations: Int = diagonals.first().size): UTMatrix<T> {
var iteration = 0

// Working copy of the diagonal list; new diagonals are appended to it, not to this.diagonals
val diagonalsMutable = diagonals.toMutableList()
// One entry per element of the last diagonal: the element plus two mutable lists seeded with it.
// NOTE(review): `a to b to c` presumably builds a Triple through a project-defined infix `to`
// (the loop below reads .second and .third) — confirm
val carry = diagonals.last().map { it to mutableListOf(it) to mutableListOf(it) }.toMutableList()

while (iteration < maxIterations && diagonalsMutable.last().size != 1) {
val next = mutableListOf<Triple<T, MutableList<T>, MutableList<T>>>()

// Each adjacent pair (i - 1, i) of carry entries yields one entry of the next diagonal,
// so every round produces a diagonal one element shorter than the previous one
for (i in 1 until carry.size) {
// Sum of pairwise products of the left neighbour's .second list and this entry's .third list
var acc = algebra.nil
for (j in carry[i - 1].second.indices) {
acc = with(algebra) { acc + (carry[i - 1].second[j] * carry[i].third[j]) }
}

// Both lists are extended in place (acc appended to the left neighbour's .second, prepended to
// this entry's .third); the same list objects are handed on to the next round via `next`
val left = carry[i - 1].second.apply { add(acc) }
val right = carry[i].third.apply { add(0, acc) }

next.add(Triple(acc, left, right))
}

// Record the freshly computed diagonal and make its entries the carry for the next round
diagonalsMutable += next.map { it.first }
carry.clear()
carry.addAll(next)
iteration++
}

return UTMatrix(
diagonals = diagonalsMutable,
algebra = algebra
)
}

// Offsets diagonals by one when converting back to matrix (superdiagonal)
fun toFullMatrix() =
if (diagonals.last().size != 1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ai.hypergraph.kaliningraph.parsing.*

object Grammars {
val sss = """START -> b | START | START START | START START START"""
.parseCFG().noNonterminalStubs
val sss by lazy {
"""START -> b | START | START START | START START START"""
.parseCFG().noNonterminalStubs
}

val ifThen = """
val ifThen by lazy {
"""
START -> X
X -> I | F | P | Q
P -> I O I | P O I
Expand All @@ -18,33 +21,34 @@ object Grammars {
BO -> and | or | xor | nand
N -> !
""".parseCFG().noNonterminalStubs
}

val toyArith = """
val toyArith by lazy { """
S -> S + S | S * S | S - S | S / S | ( S ) | - S
S -> 0 | 1 | 2 | 3 | 4
S -> X | Y | Z
""".parseCFG().noNonterminalStubs
""".parseCFG().noNonterminalStubs }

val dyckUnambig = """S -> ( S ) S | ( S ) | ( ) S | ( )""".parseCFG().noEpsilonOrNonterminalStubs
val dyck = """S -> ( S ) | ( ) | S S""".parseCFG().noEpsilonOrNonterminalStubs
val dyckUnambig by lazy { """S -> ( S ) S | ( S ) | ( ) S | ( )""".parseCFG().noEpsilonOrNonterminalStubs }
val dyck by lazy { """S -> ( S ) | ( ) | S S""".parseCFG().noEpsilonOrNonterminalStubs }

val dyckEmbedded = """
val dyckEmbedded by lazy { """
START -> ( ) | ( START ) | START START
START -> START + START | START * START
START -> 1
""".parseCFG().noNonterminalStubs
""".parseCFG().noNonterminalStubs}

val deadSimple = """S -> ( ) | ( S )""".parseCFG().noEpsilonOrNonterminalStubs
val dsNorm = """
val deadSimple by lazy { """S -> ( ) | ( S )""".parseCFG().noEpsilonOrNonterminalStubs }
val dsNorm by lazy { """
START -> START START
START -> A B
START -> A C
A -> (
B -> )
C -> START B
""".parseCFG().noEpsilonOrNonterminalStubs
""".parseCFG().noEpsilonOrNonterminalStubs }

val ocamlCFG = """
val ocamlCFG by lazy { """
S -> X
X -> A | V | ( X , X ) | X X | ( X )
A -> FUN | F | LI | M | L
Expand All @@ -69,9 +73,9 @@ object Grammars {
VO -> = | < | `||` | `&&`
I -> 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
B -> true | false
""".parseCFG().noNonterminalStubs
""".parseCFG().noNonterminalStubs }

val coarsenedPythonCFG = """
val coarsenedPythonCFG by lazy { """
S -> w | S ( S ) | ( ) | S = S | S . S | S S | ( S ) | [ S ] | { S } | : | * S | [ ]
S -> S , S | S ; S | S : S
S -> S IOP S | S BOP S
Expand All @@ -86,9 +90,9 @@ object Grammars {
S -> if S | if S else S | return S
S -> not S | S or S
S -> lambda w : S | lambda w , w : S | lambda w , w , w : S | lambda w , w , w , w : S
""".parseCFG().noNonterminalStubs
""".parseCFG().noNonterminalStubs }

val tinyC: CFG = """
val tinyC: CFG by lazy { """
START -> program
program -> statement
statement -> if paren_expr statement
Expand All @@ -103,10 +107,10 @@ object Grammars {
test -> sum | sum < sum
sum -> term | sum + term | sum - term
term -> id | int | paren_expr
""".parseCFG().freeze()
""".parseCFG().freeze() }

// https://aclanthology.org/2020.conll-1.41.pdf#page=12
val hardestCFL: CFG = """
val hardestCFL: CFG by lazy { """
S' -> R ${'$'} Q S L ;
L -> L' , U
L' -> , V L'
Expand All @@ -131,10 +135,10 @@ object Grammars {
T -> [ Q S Q ]
T -> ( Q )
T -> [ Q ]
""".trimIndent().parseCFG().noNonterminalStubs
""".trimIndent().parseCFG().noNonterminalStubs }

val shortS2PParikhMap by lazy { ParikhMap(seq2parsePythonCFG, 20) }
val seq2parsePythonCFGStr = """
val seq2parsePythonCFGStr by lazy { """
START -> Stmts_Or_Newlines
Stmts_Or_Newlines -> Stmt_Or_Newline | Stmt_Or_Newline Stmts_Or_Newlines
Stmt_Or_Newline -> Stmt | Newline
Expand Down Expand Up @@ -325,12 +329,12 @@ object Grammars {
Yield_Expr -> Yield_Keyword | Yield_Keyword Yield_Arg
Yield_Arg -> From_Keyword Test | Testlist_Endcomma
"""
""" }

val seq2parsePythonCFG: CFG = seq2parsePythonCFGStr.parseCFG().noNonterminalStubs
val seq2parsePythonVanillaCFG: CFG = seq2parsePythonCFGStr.parseCFG().noEpsilonOrNonterminalStubs
val seq2parsePythonCFG: CFG by lazy { seq2parsePythonCFGStr.parseCFG().noNonterminalStubs }
val seq2parsePythonVanillaCFG: CFG by lazy { seq2parsePythonCFGStr.parseCFG().noEpsilonOrNonterminalStubs }

val checkedArithCFG = """
val checkedArithCFG by lazy { """
START -> S
S -> S1 = S1
S -> S2 = S2
Expand Down Expand Up @@ -381,13 +385,13 @@ P6 -> P6 / P1
P7 -> P7 / P1
P8 -> P8 / P1
P9 -> P9 / P1
""".parseCFG().noNonterminalStubs.freeze()
""".parseCFG().noNonterminalStubs.freeze() }

val arith = """
val arith by lazy { """
O -> + | * | - | /
S -> S O S | ( S )
S -> 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
""".parseCFG()
""".parseCFG() }

// Terminal of the second child (index 1), or null when there is no such child or it has no terminal
private fun Tree.middle(): Σᐩ? = children.elementAtOrNull(1)?.terminal
fun Tree.evalArith(): Int = when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ import kotlin.time.*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.SetValiantTest"
*/
class SetValiantTest {
/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.SetValiantTest.testStressRecognizer"
*/
// Samples sentences from the grammar and checks that each one is recognized, while the same
// sentence with its trailing DEDENT / NEWLINE tokens dropped is not
@Test
fun testStressRecognizer() {
  val grammar = Grammars.seq2parsePythonVanillaCFG
  grammar.sliceSample(20).take(10000).forEach { sentence ->
    assertTrue(sentence.matches(grammar))
    val withoutTrailer = sentence.tokenizeByWhitespace()
      .dropLastWhile { token -> token == "DEDENT" || token == "NEWLINE" }
    assertFalse(withoutTrailer.matches(grammar))
  }
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.SetValiantTest.testSimpleGrammar"
*/
Expand Down

0 comments on commit 5a139fd

Please sign in to comment.