From fdd9f32b15a146c765d4dd3e29701802815804c1 Mon Sep 17 00:00:00 2001 From: Ryan Xiao Wei Lin Date: Wed, 8 Jan 2025 23:47:52 +0800 Subject: [PATCH] DP commit --- .../log_prob_integration.jl | 360 ++++++++++++++---- 1 file changed, 277 insertions(+), 83 deletions(-) diff --git a/src/experimental/ProbabilisticGraphicalModels/log_prob_integration.jl b/src/experimental/ProbabilisticGraphicalModels/log_prob_integration.jl index 9ac71ed4..bca30c3c 100644 --- a/src/experimental/ProbabilisticGraphicalModels/log_prob_integration.jl +++ b/src/experimental/ProbabilisticGraphicalModels/log_prob_integration.jl @@ -3,7 +3,7 @@ using Distributions using Printf ############################################################################### -# 1) Make a mutable struct and store :discrete or :continuous at creation time +# 1) BayesianNetwork definition (mutable + node_types) ############################################################################### mutable struct BayesianNetwork{V,T,F} @@ -20,9 +20,6 @@ mutable struct BayesianNetwork{V,T,F} node_types::Vector{Symbol} # e.g. :discrete or :continuous end -""" -Create an empty BayesianNetwork with Symbol variable names and Int node IDs. -""" function BayesianNetwork{V}() where {V} return BayesianNetwork( SimpleDiGraph{Int}(), @@ -35,23 +32,19 @@ function BayesianNetwork{V}() where {V} Int[], BitVector(), BitVector(), - Symbol[], # store node_types in parallel + Symbol[], ) end ############################################################################### -# 2) Add Node & Edge Helpers +# 2) Graph Helpers ############################################################################### -""" -Add a stochastic vertex with name `name`, a distribution object/function `dist`, -and a declared node_type (`:discrete` or `:continuous`). -""" function add_stochastic_vertex!( bn::BayesianNetwork{V,T}, name::V, dist::Any, - node_type::Symbol = :continuous; # default if not specified + node_type::Symbol = :continuous; # e.g. :discrete or :continuous is_observed::Bool = false )::T where {V,T} Graphs.add_vertex!(bn.graph) || return 0 @@ -66,9 +59,6 @@ function add_stochastic_vertex!( return id end -""" -Add a deterministic vertex (unused here, but for completeness). -""" function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T where {T,V,F} Graphs.add_vertex!(bn.graph) || return 0 id = nv(bn.graph) @@ -82,9 +72,6 @@ function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T w return id end -""" -Add a directed edge from `from` -> `to`. -""" function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V} from_id = bn.names_to_ids[from] to_id = bn.names_to_ids[to] @@ -92,7 +79,7 @@ function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V} end ############################################################################### -# 3) Create the 5-node chain, marking discrete/continuous explicitly +# 3) A 5-node chain, with discrete/continuous mix ############################################################################### logistic(x) = 1 / (1 + exp(-x)) @@ -123,16 +110,14 @@ function create_5node_network() end ############################################################################### -# 4) Parent Helpers: `inneighbors` instead of `predecessors` +# 4) Parent/Distribution Helpers ############################################################################### function parent_ids(bn::BayesianNetwork, node_id::Int) - # For a node_id, get all incoming edges. 
- return inneighbors(bn.graph, node_id) + return inneighbors(bn.graph, node_id) # replaced `predecessors` with `inneighbors` end function parent_values(bn::BayesianNetwork, node_id::Int) - # Retrieve the (already assigned) parent values in ascending ID order. pids = parent_ids(bn, node_id) sort!(pids) vals = Any[] @@ -146,37 +131,26 @@ function parent_values(bn::BayesianNetwork, node_id::Int) return vals end -""" -Returns the Distribution object for a node, calling its stored function if needed. -""" function get_distribution(bn::BayesianNetwork, node_id::Int)::Distribution stored = bn.distributions[node_id] if stored isa Distribution return stored elseif stored isa Function - pvals = parent_values(bn, node_id) # calls parent's values + pvals = parent_values(bn, node_id) # gather parent's assigned values return stored(pvals...) else - error("Node $node_id has invalid distribution entry (neither Distribution nor Function).") + error("Node $node_id has invalid distribution entry.") end end -""" -Check if a node is discrete by referencing the stored `node_types`. -We do NOT call get_distribution() here, which avoids the "Missing value for parent" error. -""" function is_discrete_node(bn::BayesianNetwork, node_id::Int) return bn.node_types[node_id] == :discrete end ############################################################################### -# 5) Summation & Log PDF Calculation +# 5) Logpdf Computation ############################################################################### -""" -Compute log-pdf of the current bn.values assignment. -If any parent's value is missing for a node that has a value, we return -Inf. -""" function compute_full_logpdf(bn::BayesianNetwork) logp = 0.0 for sid in bn.stochastic_ids @@ -200,49 +174,108 @@ function compute_full_logpdf(bn::BayesianNetwork) return logp end +############################################################################### +# 6) Naive Summation vs. DP Summation +############################################################################### +# We provide two ways to sum over discrete configurations: +# - sum_discrete_configurations: naive recursion +# - sum_discrete_configurations_dp: memoized recursion (DP) + """ -Naive enumeration over all unobserved discrete nodes in `discrete_ids`. -Multiply pdf(...) for each assignment, summing up to get total probability. +Naive recursion: +Enumerate all discrete node values for unobserved discrete nodes. """ -function sum_discrete_configurations(bn::BayesianNetwork, - discrete_ids::Vector{Int}, - idx::Int) +function sum_discrete_configurations( + bn::BayesianNetwork, + discrete_ids::Vector{Int}, + idx::Int +)::Float64 if idx > length(discrete_ids) - # base case: all discrete nodes assigned => evaluate logpdf for everything return exp( compute_full_logpdf(bn) ) else node_id = discrete_ids[idx] dist = get_distribution(bn, node_id) total_prob = 0.0 for val in support(dist) - bn.values[bn.names[node_id]] = val + bn.values[ bn.names[node_id] ] = val + # multiply by pdf(dist, val) total_prob += sum_discrete_configurations(bn, discrete_ids, idx+1) * pdf(dist, val) end - delete!(bn.values, bn.names[node_id]) # clean up + delete!(bn.values, bn.names[node_id]) + return total_prob + end +end + +""" +DP-based recursion: +Use a memo dictionary to store subproblem results. +Key = (idx, assigned_values_for_these_discrete_ids) + +We store the partial assignment of discrete_ids[1:idx-1], then skip redoing +the entire subtree if we've already computed it. 
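For example, with unobserved discrete nodes (X2, X3), the sub-sum over X3
computed after fixing X2 is stored under its (idx, assigned_vals) key, and any
later arrival at the same key returns the cached value instead of
re-enumerating that subtree.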
+""" +function sum_discrete_configurations_dp( + bn::BayesianNetwork, + discrete_ids::Vector{Int}, + idx::Int, + memo::Dict{Any,Float64}, + assigned_vals::Vector{Any} +)::Float64 + # If we've assigned up to idx-1, the key is (idx, assigned_vals) + # Make a stable copy of assigned_vals for the memo key. + # assigned_vals only includes the discrete nodes *so far* in order. + key = (idx, deepcopy(assigned_vals)) + + if haskey(memo, key) + return memo[key] + end + + if idx > length(discrete_ids) + # base case => compute the logpdf + result = exp( compute_full_logpdf(bn) ) + memo[key] = result + return result + else + node_id = discrete_ids[idx] + dist = get_distribution(bn, node_id) + total_prob = 0.0 + # for each possible value in support, assign + recurse + for val in support(dist) + bn.values[ bn.names[node_id] ] = val + assigned_vals[idx] = val # record partial assignment + total_prob += sum_discrete_configurations_dp( + bn, discrete_ids, idx+1, memo, assigned_vals + ) * pdf(dist, val) + end + delete!(bn.values, bn.names[node_id]) + memo[key] = total_prob return total_prob end end ############################################################################### -# 6) Create a log_posterior function +# 7) create_log_posterior with DP option ############################################################################### -function create_log_posterior(bn::BayesianNetwork) +""" +Creates a log_posterior function that merges unobserved values + sums out +unobserved discrete nodes. If use_dp=true, we use the DP approach; else naive. +""" +function create_log_posterior(bn::BayesianNetwork; use_dp::Bool=false) function log_posterior(unobserved_values::Dict{Symbol,Float64}) # Save old BN state old_values = copy(bn.values) try - # Merge the unobserved values into bn.values - for (k,v) in unobserved_values + # Merge unobserved + for (k, v) in unobserved_values bn.values[k] = v end - # Identify unobserved, discrete nodes => must sum out + # Identify unobserved discrete IDs unobs_discrete_ids = Int[] for sid in bn.stochastic_ids if !bn.is_observed[sid] varname = bn.names[sid] - # If we haven't assigned a value for varname, it is unobserved if !haskey(bn.values, varname) && is_discrete_node(bn, sid) push!(unobs_discrete_ids, sid) end @@ -253,12 +286,25 @@ function create_log_posterior(bn::BayesianNetwork) # no discrete marginalization => direct logpdf return compute_full_logpdf(bn) else - # sum out the discrete ids - prob_sum = sum_discrete_configurations(bn, unobs_discrete_ids, 1) - return log(prob_sum) + # sum out discrete configurations + if use_dp + # DP approach + memo = Dict{Any,Float64}() + assigned_vals = Vector{Any}(undef, length(unobs_discrete_ids)) + for i in 1:length(unobs_discrete_ids) + assigned_vals[i] = nothing + end + prob_sum = sum_discrete_configurations_dp( + bn, unobs_discrete_ids, 1, memo, assigned_vals + ) + return log(prob_sum) + else + # naive recursion + prob_sum = sum_discrete_configurations(bn, unobs_discrete_ids, 1) + return log(prob_sum) + end end finally - # restore bn.values = old_values end end @@ -266,75 +312,223 @@ function create_log_posterior(bn::BayesianNetwork) end ############################################################################### -# 7) Evaluate the model for a set of observations & X1 values +# 8) Evaluate function ############################################################################### -function evaluate_model(bn::BayesianNetwork, obs::Dict{Symbol,Float64}, - X1_values, description::AbstractString) +function evaluate_model(bn::BayesianNetwork; + 
obs::Dict{Symbol,Float64}=Dict(), + X1_values=0.0:0.5:1.5, + description="", + use_dp::Bool=false +) println("\n=== $description ===") - println("Observations: ", obs) + println("Observations: $obs") + println("use_dp = $use_dp") - # Save old BN state - old_values = copy(bn.values) - old_observed = copy(bn.is_observed) + old_values = copy(bn.values) + old_obs = copy(bn.is_observed) try - # Condition on `obs` by storing them + marking is_observed + # Condition on obs for (k, v) in obs id = bn.names_to_ids[k] bn.values[k] = v bn.is_observed[id] = true end - # Build the log posterior function - log_post = create_log_posterior(bn) + # create log posterior + log_post = create_log_posterior(bn; use_dp=use_dp) - # Evaluate log posterior for each candidate X1 + # evaluate over X1 results = [(x1, log_post(Dict(:X1 => x1))) for x1 in X1_values] - # Convert to normalized posterior + # normalize max_lp = maximum(last.(results)) - normalized_post = [(x1, exp(lp - max_lp)) for (x1, lp) in results] + posterior = [(x1, exp(lp - max_lp)) for (x1, lp) in results] - # Print - for (x1, p) in normalized_post + # print + for (x1, p) in posterior @printf(" X1 = %.2f => normalized posterior = %.5f\n", x1, p) end + finally - # restore BN state bn.values = old_values - bn.is_observed = old_observed + bn.is_observed = old_obs end end ############################################################################### -# 8) Demonstration (Same 5-Node chain + 3 test cases) +# 9) Demonstration ############################################################################### +# (A) The Original 5-node chain model_5_nodes = create_5node_network() - X1_values = 0.0:0.5:1.5 -println("\n=== Running test cases on the 5-node model ===") +println("\n=== Running test cases on the 5-node chain ===") +# 1) Observing X4=1, X5=2 => marginalizing X2,X3 evaluate_model( model_5_nodes, - Dict(:X4 => 1.0, :X5 => 2.0), - X1_values, - "5-Node Model (X4=1.0, X5=2.0, marginalizing X2,X3)" + obs=Dict(:X4=>1.0, :X5=>2.0), + X1_values=X1_values, + description="5-Node Model (X4=1.0, X5=2.0, marginalizing X2,X3) [DP=FALSE]", + use_dp=false ) +# same scenario but use_dp=true evaluate_model( model_5_nodes, - Dict(:X5 => 2.0), - X1_values, - "5-Node Model (X5=2.0, marginalizing X2,X3,X4)" + obs=Dict(:X4=>1.0, :X5=>2.0), + X1_values=X1_values, + description="5-Node Model (X4=1.0, X5=2.0, marginalizing X2,X3) [DP=TRUE]", + use_dp=true ) +# 2) Observing only X5=2.0 => marginalize X2,X3,X4 evaluate_model( model_5_nodes, - Dict(:X2 => 1.0, :X3 => 1.0, :X4 => 1.0, :X5 => 2.0), - X1_values, - "5-Node Model (all observed: X2,X3,X4,X5)" + obs=Dict(:X5=>2.0), + X1_values=X1_values, + description="5-Node Model (X5=2.0, marginalizing X2,X3,X4) [DP=TRUE]", + use_dp=true ) +# 3) Observing X2=1, X3=1, X4=1, X5=2 => only X1 is unknown +evaluate_model( + model_5_nodes, + obs=Dict(:X2=>1.0, :X3=>1.0, :X4=>1.0, :X5=>2.0), + X1_values=X1_values, + description="5-Node Model (all observed except X1) [DP=TRUE]", + use_dp=true +) + + +############################################################################### +# 10) More Complex/Branching Example to Show DP Gains +############################################################################### + +""" +Construct a BN with multiple discrete nodes that share parents, +to demonstrate repeated subproblems for DP. 
Structure (simplified example):

        X1              (discrete here)
       /  \
     X2    X3           (both discrete, each depends on X1)
       \  /
        X4              (discrete, depends on X2 and X3)
        |
        X5              (continuous, depends on X4)

We'll artificially inflate the support of X2, X3, X4 to highlight DP benefits.
Note that Categorical([p1, ..., pk]) from Distributions.jl has support
{1, ..., k}, so every discrete node below takes 1-based values.
"""
function create_branching_network()
    bn = BayesianNetwork{Symbol}()

    # X1 is discrete with a bigger support (three states).
    # Could be continuous, but discrete might emphasize the DP.
    x1_dist = Categorical([0.3, 0.4, 0.3]) # support {1,2,3}
    add_stochastic_vertex!(bn, :X1, x1_dist, :discrete)

    # X2 depends on X1 => each X1 value defines a different distribution,
    # with X2 in {1,2,3} and varied probabilities
    function x2_distfn(x1)
        # x1 is 1, 2, or 3 => pick a probability vector per case
        if x1 == 1
            return Categorical([0.5, 0.3, 0.2])
        elseif x1 == 2
            return Categorical([0.2, 0.5, 0.3])
        else # x1==3
            return Categorical([0.1, 0.2, 0.7])
        end
    end
    add_stochastic_vertex!(bn, :X2, x2_distfn, :discrete)
    add_edge!(bn, :X1, :X2)

    # X3 depends on X1 => similarly
    function x3_distfn(x1)
        if x1 == 1
            return Categorical([0.4, 0.6]) # 2 possible states
        elseif x1 == 2
            return Categorical([0.7, 0.3])
        else # x1==3
            return Categorical([0.3, 0.7])
        end
    end
    add_stochastic_vertex!(bn, :X3, x3_distfn, :discrete)
    add_edge!(bn, :X1, :X3)

    # X4 depends on X2 and X3 => a two-argument distribution function.
    # X2 is in {1,2,3} and X3 in {1,2}, so X4 again gets three states.
    function x4_distfn(x2, x3)
        # just a made-up table; a big if-else block keeps it explicit
        # x2=1, x3=1 => [0.2, 0.5, 0.3], etc.
        if x2 == 1 && x3 == 1
            return Categorical([0.2, 0.5, 0.3])
        elseif x2 == 1 && x3 == 2
            return Categorical([0.1, 0.2, 0.7])
        elseif x2 == 2 && x3 == 1
            return Categorical([0.5, 0.3, 0.2])
        elseif x2 == 2 && x3 == 2
            return Categorical([0.3, 0.4, 0.3])
        elseif x2 == 3 && x3 == 1
            return Categorical([0.25, 0.25, 0.5])
        else # x2==3 && x3==2
            return Categorical([0.2, 0.2, 0.6])
        end
    end
    add_stochastic_vertex!(bn, :X4, x4_distfn, :discrete)
    add_edge!(bn, :X2, :X4)
    add_edge!(bn, :X3, :X4)

    # X5 depends on X4 => continuous
    # X4=1 => Normal(-3,2); X4=2 => Normal(0,1); X4=3 => Normal(5,1)
    function x5_distfn(x4)
        if x4 == 1
            return Normal(-3, 2)
        elseif x4 == 2
            return Normal(0, 1)
        else # x4==3
            return Normal(5, 1)
        end
    end
    add_stochastic_vertex!(bn, :X5, x5_distfn, :continuous)
    add_edge!(bn, :X4, :X5)

    return bn
end

# Let's build & run a test to show how DP helps when enumerating
println("\n=== More Complex/Branching BN test to demonstrate DP gains ===")
branching_bn = create_branching_network()

# Suppose we observe X5=4.2, i.e. we know the continuous node and must
# marginalize out X1, X2, X3, X4, all of which are discrete.
obs = Dict(:X5 => 4.2)
id5 = branching_bn.names_to_ids[:X5]
branching_bn.values[:X5] = 4.2
branching_bn.is_observed[id5] = true

# Evaluate log posterior with & without DP
function test_branching_bn(branching_bn)
    # X1 is discrete here, so there is no X1 grid to scan; we only want the
    # log posterior itself, summed over all unobserved discrete nodes.
    log_post_naive = create_log_posterior(branching_bn; use_dp=false)
    log_post_dp = create_log_posterior(branching_bn; use_dp=true)

    # Call each closure once with no extra unobserved values, which forces
    # the code to sum over X1, X2, X3, X4.
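    # For reference: that sweep visits 3 (X1) * 3 (X2) * 2 (X3) * 3 (X4) = 54
    # joint configurations, and the naive recursion evaluates
    # compute_full_logpdf once per configuration.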
    println(">>> Branching BN: Observing X5=4.2, no other nodes => big discrete sum <<<")
    println("Naive approach ...")
    @time naive_lp = log_post_naive(Dict{Symbol,Float64}()) # typed empty Dict matches the closure's argument type

    println("DP approach ...")
    @time dp_lp = log_post_dp(Dict{Symbol,Float64}())

    println("Naive log posterior = $naive_lp, DP log posterior = $dp_lp\n")
end

test_branching_bn(branching_bn)

println("\nDone.")
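# Optional cross-check (a sketch that uses only functions defined above): the
# naive and DP summations compute the same marginal, so they should agree on
# the 5-node chain as well, up to floating-point roundoff. X5 = 2.0 and
# X1 = 0.5 mirror the grid used in the demonstrations.
let bn = create_5node_network()
    bn.values[:X5] = 2.0
    bn.is_observed[bn.names_to_ids[:X5]] = true
    lp_naive = create_log_posterior(bn; use_dp=false)(Dict(:X1 => 0.5))
    lp_dp = create_log_posterior(bn; use_dp=true)(Dict(:X1 => 0.5))
    @printf("Cross-check at X1 = 0.50: naive = %.10f, DP = %.10f\n", lp_naive, lp_dp)
end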