Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Jan 14, 2025
1 parent 6e0a0f4 commit acaa456
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ function decondition!(bn::BayesianNetwork{V}, deconditioning_variables::Vector{V
return bn
end


"""
Add a stochastic vertex to the BayesianNetwork.
- `dist` can be a `Distribution` or a function returning a `Distribution`.
Expand All @@ -124,8 +123,8 @@ function add_stochastic_vertex!(
bn::BayesianNetwork{V,T},
name::V,
dist::Any,
node_type::Symbol = :continuous;
is_observed::Bool = false
node_type::Symbol=:continuous;
is_observed::Bool=false,
)::T where {V,T}
Graphs.add_vertex!(bn.graph) || return 0
id = nv(bn.graph)
Expand Down Expand Up @@ -161,7 +160,7 @@ Add a directed edge `from -> to` in the BayesianNetwork's graph.
"""
function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V}
from_id = bn.names_to_ids[from]
to_id = bn.names_to_ids[to]
to_id = bn.names_to_ids[to]
return Graphs.add_edge!(bn.graph, from_id, to_id)
end

Expand Down Expand Up @@ -398,7 +397,7 @@ function compute_full_logpdf(bn::BayesianNetwork)
end
end
dist = get_distribution(bn, sid)
val = bn.values[varname]
val = bn.values[varname]
lpdf = logpdf(dist, val)
if isinf(lpdf)
return -Inf
Expand All @@ -419,19 +418,18 @@ Enumerate all discrete node values for unobserved discrete nodes.
Returns a *probability sum*, i.e. sum over exp(logpdf).
"""
function sum_discrete_configurations(
bn::BayesianNetwork,
discrete_ids::Vector{Int},
idx::Int
bn::BayesianNetwork, discrete_ids::Vector{Int}, idx::Int
)::Float64
if idx > length(discrete_ids)
return exp( compute_full_logpdf(bn) )
return exp(compute_full_logpdf(bn))
else
node_id = discrete_ids[idx]
dist = get_distribution(bn, node_id)
total_prob = 0.0
for val in support(dist)
bn.values[ bn.names[node_id] ] = val
total_prob += sum_discrete_configurations(bn, discrete_ids, idx+1) * pdf(dist, val)
bn.values[bn.names[node_id]] = val
total_prob +=
sum_discrete_configurations(bn, discrete_ids, idx + 1) * pdf(dist, val)
end
delete!(bn.values, bn.names[node_id])
return total_prob
Expand Down

0 comments on commit acaa456

Please sign in to comment.