From 20a2afe488e7a23700c27bd25841300c205bee99 Mon Sep 17 00:00:00 2001 From: breandan Date: Thu, 20 Jun 2024 23:41:10 -0400 Subject: [PATCH] steerable random walk through dfa --- .github/workflows/main.yml | 4 ++-- .../ai/hypergraph/markovian/mcmc/MarkovChain.kt | 17 ++++++++++++----- .../kaliningraph/automata/WFSATest.kt | 17 ++++++++++++++++- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6584cfee..7ff2621f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -13,9 +13,9 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up JDK 17 + - name: Set up JDK 21 uses: actions/setup-java@v1 with: - java-version: 17 + java-version: 21 - name: Build with Gradle run: ./gradlew -PleaseExcludeBenchmarks allTests --stacktrace \ No newline at end of file diff --git a/src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt b/src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt index 21039645..43284d12 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt @@ -119,7 +119,7 @@ open class MarkovChain( val dists: LRUCache, Dist> = LRUCache() // Computes perplexity of a sequence normalized by sequence length (lower is better) - fun score(seq: List): Double = + fun score(seq: List): Double = if (memory < seq.size) -seq.windowed(memory) .map { (getAtLeastOne(it) + 1) / (getAtLeastOne(it.dropLast(1) + null) + dictionary.size) } .sumOf { ln(it) } / seq.size @@ -257,14 +257,21 @@ open class MarkovChain( var total = 0L lines.map { it.substringBefore(CSVSEP).split(" ") to it.substringAfter(CSVSEP).toLong() } .forEach { (ngram, count) -> - total += count - nrmCounts.update(ngram, count) + val padding = List(memory - 1) { null } + val windows = (padding + ngram + padding).windowed(memory, 1) + total += count * windows.size + windows.forEach { nrmCounts.update(it, count) } ngram.forEach { rawCounts.update(it, count) } } return MarkovChain( - train = sequenceOf(), + train = sequenceOf(), // Empty since we already know the counts, no need to retrain memory = memory, - Counter(total = AtomicInteger(total.toInt()), memory = memory, rawCounts = rawCounts, nrmCounts = nrmCounts) + Counter( + total = AtomicInteger(total.toInt()), + memory = memory, + rawCounts = rawCounts, + nrmCounts = nrmCounts + ) ) } } diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt index e0125a11..18a85e0f 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt @@ -4,6 +4,7 @@ import Grammars import Grammars.shortS2PParikhMap import ai.hypergraph.kaliningraph.graphs.LabeledGraph import ai.hypergraph.kaliningraph.parsing.* +import ai.hypergraph.markovian.mcmc.MarkovChain import net.jhoogland.jautomata.* import net.jhoogland.jautomata.Automaton import net.jhoogland.jautomata.operations.* @@ -62,6 +63,20 @@ class WFSATest { .replace("Mrecord", "circle") // FSA states should be circular .replace("null", "ε") // null label = ε-transition + /* + * Returns a sequence trajectories through a DFA sampled using the Markov chain. + * The DFA is expected to be deterministic. We use the Markov chain to steer the + * random walk through the DFA by sampling the best transitions conditioned on the + * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1}) + */ + + fun Automaton.randomWalk(mc: MarkovChain, topK: Int = 1000): Sequence { + val init = initialStates().first() + val padding = List(mc.memory - 1) { null } + val ts = transitionsOut(init).map { (it as BasicTransition).label() }.map { it to mc.score(padding + it) } + return TODO() + } + /* ./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.automata.WFSATest.testPTreeVsWFSA" */ @@ -70,7 +85,7 @@ class WFSATest { val toRepair = "NAME : NEWLINE NAME = STRING NEWLINE NAME = NAME . NAME ( STRING ) NEWLINE" val radius = 2 val pt = Grammars.seq2parsePythonCFG.makeLevPTree(toRepair, radius, shortS2PParikhMap) - println(pt.totalTrees.toString()) + println("Total trees: " + pt.totalTrees.toString()) val maxResults = 10_000 val ptreeRepairs = measureTimedValue { pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet()