
Commit

Merge branch 'py-why:main' into main
EvieQ01 authored Feb 6, 2025
2 parents 509ebe7 + 9dc0365 commit 165ed2d
Showing 19 changed files with 1,456 additions and 33 deletions.
8 changes: 4 additions & 4 deletions causallearn/score/LocalScoreFunctionClass.py
@@ -29,7 +29,7 @@ def __init__(
self.parameters = parameters
self.score_cache = {}

-        if self.local_score_fun == local_score_BIC_from_cov:
+        if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
self.cov = np.cov(self.data.T)
self.n = self.data.shape[0]

@@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
hash_key = tuple(sorted(PAi))

if not self.score_cache[i].__contains__(hash_key):
-            if self.local_score_fun == local_score_BIC_from_cov:
+            if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
else:
self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters)

return self.score_cache[i][hash_key]

def score_nocache(self, i: int, PAi: List[int]) -> float:
-        if self.local_score_fun == local_score_BIC_from_cov:
+        if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
else:
-            return self.local_score_fun(self.data, i, PAi, self.parameters)
+            return self.local_score_fun(self.data, i, PAi, self.parameters)
20 changes: 10 additions & 10 deletions causallearn/search/PermutationBased/BOSS.py
@@ -53,7 +53,7 @@ def boss(
if n < p:
warnings.warn("The number of features is much larger than the sample size!")

-    if score_func == "local_score_CV_general":
+    if score_func == "local_score_CV_general":
# % k-fold negative cross validated likelihood based on regression in RKHS
if parameters is None:
parameters = {
@@ -63,13 +63,13 @@
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_general, parameters=parameters
)
-    elif score_func == "local_score_marginal_general":
+    elif score_func == "local_score_marginal_general":
# negative marginal likelihood based on regression in RKHS
parameters = {}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
)
-    elif score_func == "local_score_CV_multi":
+    elif score_func == "local_score_CV_multi":
# k-fold negative cross validated likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
@@ -83,7 +83,7 @@
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
)
-    elif score_func == "local_score_marginal_multi":
+    elif score_func == "local_score_marginal_multi":
# negative marginal likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
@@ -93,22 +93,22 @@
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
)
-    elif score_func == "local_score_BIC":
+    elif score_func == "local_score_BIC":
# SEM BIC score
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
warnings.warn("Using 'local_score_BIC_from_cov' instead for efficiency")
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
-            data=X, local_score_fun=local_score_BIC, parameters=parameters
+            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
-    elif score_func == "local_score_BIC_from_cov":
+    elif score_func == "local_score_BIC_from_cov":
# SEM BIC score
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
-    elif score_func == "local_score_BDeu":
+    elif score_func == "local_score_BDeu":
# BDeu score
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,4 +204,4 @@ def better_mutation(v, order, gsts):
order.remove(v)
order.insert(best - int(best > i), v)

-    return True
+    return True
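
An editorial aside, not part of the commit: the rerouting of "local_score_BIC" above pairs with the LocalScoreFunctionClass diff earlier, where the covariance-based scorer is invoked with (cov, n) rather than the raw data matrix, so one upfront covariance computation replaces repeated per-call work. A minimal sketch of that dispatch, paraphrasing the class shown earlier:

def dispatch_score(local_score_fun, data, cov, n, i, PAi, parameters):
    # The covariance-based BIC consumes (cov, n); every other local
    # scorer consumes the raw data matrix, mirroring
    # LocalScoreFunctionClass.score in this commit.
    if local_score_fun.__name__ == 'local_score_BIC_from_cov':
        return local_score_fun((cov, n), i, PAi, parameters)
    return local_score_fun(data, i, PAi, parameters)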
20 changes: 10 additions & 10 deletions causallearn/search/PermutationBased/GRaSP.py
@@ -16,7 +16,7 @@
local_score_marginal_general,
local_score_marginal_multi,
)
-from causallearn.search.PermutationBased.gst import GST;
+from causallearn.search.PermutationBased.gst import GST
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
from causallearn.utils.DAG2CPDAG import dag2cpdag

@@ -111,7 +111,7 @@ def grasp(
if n < p:
warnings.warn("The number of features is much larger than the sample size!")

-    if score_func == "local_score_CV_general":
+    if score_func == "local_score_CV_general":
# k-fold negative cross validated likelihood based on regression in RKHS
if parameters is None:
parameters = {
@@ -127,7 +127,7 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
)
-    elif score_func == "local_score_CV_multi":
+    elif score_func == "local_score_CV_multi":
# k-fold negative cross validated likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
@@ -141,7 +141,7 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
)
-    elif score_func == "local_score_marginal_multi":
+    elif score_func == "local_score_marginal_multi":
# negative marginal likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
@@ -151,22 +151,22 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
)
-    elif score_func == "local_score_BIC":
+    elif score_func == "local_score_BIC":
# SEM BIC score
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
warnings.warn("Using 'local_score_BIC_from_cov' instead for efficiency")
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
-            data=X, local_score_fun=local_score_BIC, parameters=parameters
+            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
-    elif score_func == "local_score_BIC_from_cov":
+    elif score_func == "local_score_BIC_from_cov":
# SEM BIC score
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
-    elif score_func == "local_score_BDeu":
+    elif score_func == "local_score_BDeu":
# BDeu score
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,7 +204,7 @@ def grasp(
sys.stdout.flush()

runtime = time.perf_counter() - runtime

if verbose:
sys.stdout.write("\nGRaSP completed in: %.2fs \n" % runtime)
sys.stdout.flush()
212 changes: 212 additions & 0 deletions causallearn/search/ScoreBased/CALM.py
@@ -0,0 +1,212 @@
import numpy as np
import torch
import torch.nn as nn
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network
from causallearn.utils.CALMUtils import *
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from typing import Any, Dict
from scipy.special import expit as sigmoid

torch.set_default_dtype(torch.double)

def calm(
X: np.ndarray,
lambda1: float = 0.005,
alpha: float = 0.01,
tau: float = 0.5,
rho_init: float = 1e-5,
rho_mult: float = 3,
htol: float = 1e-8,
subproblem_iter: int = 40000,
standardize: bool = False,
device: str = 'cpu'
) -> Dict[str, Any]:
"""
Perform the CALM (Continuous and Acyclicity-constrained L0-penalized likelihood with estimated Moral graph) algorithm.
Parameters
----------
X : numpy.ndarray
Input dataset of shape (n, d), where n is the number of samples,
and d is the number of variables.
lambda1 : float, optional
Coefficient for the approximated L0 penalty, which encourages sparsity in the learned graph. Default is 0.005.
alpha : float, optional
Significance level for conditional independence tests. Default is 0.01.
tau : float, optional
Temperature parameter for the Gumbel-Sigmoid. Default is 0.5.
rho_init : float, optional
Initial value of the penalty parameter for the acyclicity constraint. Default is 1e-5.
rho_mult : float, optional
Multiplication factor for rho in each iteration. Default is 3.
htol : float, optional
Tolerance level for acyclicity constraint. Default is 1e-8.
subproblem_iter : int, optional
Number of iterations for subproblem optimization. Default is 40000.
standardize : bool, optional
Whether to standardize the input data (mean=0, variance=1). Default is False.
device : str, optional
The device to use for computation ('cpu' or 'cuda'). Default is 'cpu'.
Returns
-------
Record : dict
A dictionary containing:
- Record['G']: learned causal graph, a DAG, where: Record['G'].graph[j,i]=1 and Record['G'].graph[i,j]=-1 indicates i --> j.
- Record['B_weighted']: weighted adjacency matrix of the learned causal graph.
"""

d = X.shape[1]
if standardize:
mean_X = np.mean(X, axis=0, keepdims=True)
std_X = np.std(X, axis=0, keepdims=True)
X = (X - mean_X) / std_X
else:
X = X - np.mean(X, axis=0, keepdims=True)

# Compute the data covariance matrix
cov_emp = np.cov(X.T, bias=True)

# Learn the moral graph using the IAMB Markov network
moral_mask, _ = iamb_markov_network(X, alpha=alpha)

# Initialize and run the CalmModel
device = torch.device(device)
cov_emp = torch.from_numpy(cov_emp).to(device)
moral_mask = torch.from_numpy(moral_mask).float().to(device)

model = CalmModel(d, moral_mask, tau=tau, lambda1=lambda1).to(device)

# Optimization loop
rho = rho_init
for _ in range(100):
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(subproblem_iter):
optimizer.zero_grad()
loss = model.compute_loss(cov_emp, rho)
loss.backward(retain_graph=True)
optimizer.step()

with torch.no_grad():
B_logit_copy = model.B_logit.detach().clone()
B_logit_copy[model.moral_mask == 0] = float('-inf')
h_sigmoid = model.compute_h(torch.sigmoid(B_logit_copy / model.tau))

rho *= rho_mult
if h_sigmoid.item() <= htol or rho > 1e+16:
break

# Extract the final binary and weighted adjacency matrices
params_est = model.get_params()
B_bin, B_weighted = params_est['B_bin'], params_est['B']

node_names = [("X%d" % (i + 1)) for i in range(d)]
nodes = [GraphNode(name) for name in node_names]
G = GeneralGraph(nodes)

# Add edges to the GeneralGraph based on B_bin
for i in range(d):
for j in range(d):
if B_bin[i, j] == 1:
G.add_directed_edge(nodes[i], nodes[j])

Record = {
"G": G, # GeneralGraph object representing the learned causal graph, a DAG
"B_weighted": B_weighted # Weighted adjacency matrix of the learned graph
}

return Record
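
An editorial aside, not part of the commit: a minimal usage sketch of the new calm() entry point. The synthetic chain data, seed, and printed fields are illustrative assumptions based on the docstring above.

import numpy as np
from causallearn.search.ScoreBased.CALM import calm

# Toy linear-Gaussian chain X1 -> X2 -> X3.
rng = np.random.default_rng(0)
n = 1000
x1 = rng.normal(size=n)
x2 = 0.8 * x1 + rng.normal(size=n)
x3 = -0.5 * x2 + rng.normal(size=n)
X = np.column_stack([x1, x2, x3])

Record = calm(X, lambda1=0.005, alpha=0.01, standardize=True)
print(Record["G"])           # learned DAG as a GeneralGraph
print(Record["B_weighted"])  # weighted adjacency matrix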

class CalmModel(nn.Module):
"""
The CALM model
Parameters
----------
d : int
Number of variables/nodes in the graph.
moral_mask : torch.Tensor
Binary mask representing the moral graph structure, used to restrict possible edges.
tau : float, optional
Temperature parameter for the Gumbel-Sigmoid sampling, controlling the sparsity approximation. Default is 0.5.
lambda1 : float, optional
Coefficient for the approximated L0 penalty (sparsity term). Default is 0.005.
"""
def __init__(self, d, moral_mask, tau=0.5, lambda1=0.005):
super(CalmModel, self).__init__()
self.d = d
self.moral_mask = moral_mask
self.tau = tau
self.lambda1 = lambda1
self._init_params()

def _init_params(self):
"""Initialize parameters"""
self.B_param = nn.Parameter(
torch.FloatTensor(self.d, self.d).uniform_(-0.001, 0.001).to(self.moral_mask.device)
)
self.B_logit = nn.Parameter(
torch.zeros(self.d, self.d).to(self.moral_mask.device)
)

def sample_mask(self):
"""
Samples a binary mask B_mask based on the Gumbel-Sigmoid approximation.
Applies the moral graph mask to restrict possible edges.
"""
B_mask = gumbel_sigmoid(self.B_logit, tau=self.tau)
B_mask = B_mask * self.moral_mask
return B_mask

@torch.no_grad()
def get_params(self):
"""
        Returns the estimated adjacency matrices: B_bin (binary) and B (weighted), thresholding edge probabilities at 0.5.
"""
threshold = 0.5
B_param = self.B_param.cpu().detach().numpy()
B_logit = self.B_logit.cpu().detach().numpy()
B_logit[self.moral_mask.cpu().numpy() == 0] = float('-inf')
B_bin = sigmoid(B_logit / self.tau)
B_bin[B_bin < threshold] = 0
B_bin[B_bin >= threshold] = 1
B = B_bin * B_param
params = {'B': B, 'B_bin': B_bin}
return params

def compute_likelihood(self, B, cov_emp):
"""
Computes the likelihood-based objective function for non-equal noise variance (NV) assumption.
"""
I = torch.eye(self.d, device=self.B_param.device)
residuals = torch.diagonal((I - B).T @ cov_emp @ (I - B))
likelihood = 0.5 * torch.sum(torch.log(residuals)) - torch.linalg.slogdet(I - B)[1]
return likelihood

def compute_sparsity(self, B_mask):
"""
Computes the sparsity penalty (approximated L0 penalty) by summing the binary entries in B_mask.
"""
return B_mask.sum()

def compute_h(self, B_mask):
"""
Computes the DAG constraint term, adapted from the DAG constraint formulation
in Yu et al. (2019).
"""
return torch.trace(matrix_poly(B_mask, self.d, self.B_param.device)) - self.d

def compute_loss(self, cov_emp, rho):
"""
Combines likelihood, approximated L0 penalty (sparsity), and DAG constraint terms into the final loss function.
"""
B_mask = self.sample_mask()
B = B_mask * self.B_param
likelihood = self.compute_likelihood(B, cov_emp)
sparsity = self.lambda1 * self.compute_sparsity(B_mask)
h = self.compute_h(B_mask)
loss = likelihood + sparsity + 0.5 * rho * h**2
return loss
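
An editorial aside, not part of the commit: the class relies on gumbel_sigmoid and matrix_poly from the star import of causallearn.utils.CALMUtils, whose implementations are not shown in this diff. The following is a hedged sketch of what the call sites appear to assume (matrix_poly realizing the DAG-GNN acyclicity term of Yu et al. (2019), gumbel_sigmoid a Binary Concrete relaxation), not the library's definitive code:

import torch

def matrix_poly(A, d, device):
    # (I + A/d)^d; trace((I + A/d)^d) - d (see compute_h) is zero
    # exactly when the masked graph is acyclic (Yu et al., 2019).
    I = torch.eye(d, device=device)
    return torch.linalg.matrix_power(I + A / d, d)

def gumbel_sigmoid(logits, tau=0.5):
    # Binary Concrete / Gumbel-Sigmoid: perturb logits with Logistic(0, 1)
    # noise, then apply a temperature-scaled sigmoid so edge indicators
    # stay differentiable.
    u = torch.rand_like(logits)
    noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1) sample
    return torch.sigmoid((logits + noise) / tau)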


2 changes: 2 additions & 0 deletions causallearn/search/ScoreBased/ExactSearch.py
@@ -365,6 +365,8 @@ def bic_score_node(X, i, structure):
b=X[:, i],
rcond=None)
bic = n * np.log(residual / n) + len(structure) * np.log(n)
+    if bic.size == 0:
+        return NEGINF # Return negative infinity if bic is empty
return bic.item()
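
An editorial aside on the new guard: np.linalg.lstsq returns an empty residuals array when the design matrix is rank-deficient or the system is underdetermined, in which case the derived bic array is empty and .item() would raise. A small demonstration of that NumPy behavior, with made-up data:

import numpy as np

A = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])  # rank 1
b = np.array([1.0, 2.0, 3.0])
_, residual, rank, _ = np.linalg.lstsq(A, b, rcond=None)
print(rank)           # 1, less than the 2 columns
print(residual.size)  # 0: empty, so no residual sum is reported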

