Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add two variants of the KCI test #202

Merged
merged 16 commits into from
Nov 5, 2024
534 changes: 534 additions & 0 deletions causallearn/utils/FastKCI/FastKCI.py

Large diffs are not rendered by default.

Empty file.
403 changes: 403 additions & 0 deletions causallearn/utils/RCIT/RCIT.py

Large diffs are not rendered by default.

Empty file.
58 changes: 56 additions & 2 deletions causallearn/utils/cit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from scipy.stats import chi2, norm

from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd
from causallearn.utils.FastKCI.FastKCI import FastKCI_CInd, FastKCI_UInd
from causallearn.utils.RCIT.RCIT import RCIT as RCIT_CInd
from causallearn.utils.RCIT.RCIT import RIT as RCIT_UInd
from causallearn.utils.PCUtils import Helper

CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5
Expand All @@ -13,6 +16,8 @@
mv_fisherz = "mv_fisherz"
mc_fisherz = "mc_fisherz"
kci = "kci"
rcit = "rcit"
fastkci = "fastkci"
chisq = "chisq"
gsq = "gsq"
d_separation = "d_separation"
Expand All @@ -23,15 +28,19 @@ def CIT(data, method='fisherz', **kwargs):
Parameters
----------
data: numpy.ndarray of shape (n_samples, n_features)
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "chisq", "gsq"]
kwargs: placeholder for future arguments, or for KCI specific arguments now
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", "chisq", "gsq"]
kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
TODO: utimately kwargs should be replaced by explicit named parameters.
check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
'''
if method == fisherz:
return FisherZ(data, **kwargs)
elif method == kci:
return KCI(data, **kwargs)
elif method == fastkci:
return FastKCI(data, **kwargs)
elif method == rcit:
return RCIT(data, **kwargs)
elif method in [chisq, gsq]:
return Chisq_or_Gsq(data, method_name=method, **kwargs)
elif method == mv_fisherz:
Expand All @@ -43,6 +52,7 @@ def CIT(data, method='fisherz', **kwargs):
else:
raise ValueError("Unknown method: {}".format(method))


class CIT_Base(object):
# Base class for CIT, contains basic operations for input check and caching, etc.
def __init__(self, data, cache_path=None, **kwargs):
Expand Down Expand Up @@ -193,6 +203,50 @@ def __call__(self, X, Y, condition_set=None):
self.pvalue_cache[cache_key] = p
return p

class FastKCI(CIT_Base):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
kci_ui_kwargs = {k: v for k, v in kwargs.items() if k in
['K', 'J', 'alpha']}
kci_ci_kwargs = {k: v for k, v in kwargs.items() if k in
['K', 'J', 'alpha', 'use_gp']}
self.check_cache_method_consistent(
'kci', hashlib.md5(json.dumps(kci_ci_kwargs, sort_keys=True).encode('utf-8')).hexdigest())
self.assert_input_data_is_valid()
self.kci_ui = FastKCI_UInd(**kci_ui_kwargs)
self.kci_ci = FastKCI_CInd(**kci_ci_kwargs)

def __call__(self, X, Y, condition_set=None):
# Kernel-based conditional independence test.
Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key]
p = self.kci_ui.compute_pvalue(self.data[:, Xs], self.data[:, Ys])[0] if len(condition_set) == 0 else \
self.kci_ci.compute_pvalue(self.data[:, Xs], self.data[:, Ys], self.data[:, condition_set])[0]
self.pvalue_cache[cache_key] = p
return p

class RCIT(CIT_Base):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
rit_kwargs = {k: v for k, v in kwargs.items() if k in
['approx']}
rcit_kwargs = {k: v for k, v in kwargs.items() if k in
['approx', 'num_f', 'num_f2', 'rcit']}
self.check_cache_method_consistent(
'kci', hashlib.md5(json.dumps(rcit_kwargs, sort_keys=True).encode('utf-8')).hexdigest())
self.assert_input_data_is_valid()
self.rit = RCIT_UInd(**rit_kwargs)
self.rcit = RCIT_CInd(**rcit_kwargs)

def __call__(self, X, Y, condition_set=None):
# Kernel-based conditional independence test.
Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
if cache_key in self.pvalue_cache: return self.pvalue_cache[cache_key]
p = self.rit.compute_pvalue(self.data[:, Xs], self.data[:, Ys])[0] if len(condition_set) == 0 else \
self.rcit.compute_pvalue(self.data[:, Xs], self.data[:, Ys], self.data[:, condition_set])[0]
self.pvalue_cache[cache_key] = p
return p

class Chisq_or_Gsq(CIT_Base):
def __init__(self, data, method_name, **kwargs):
def _unique(column):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
'matplotlib',
'networkx',
'pydot',
'tqdm'
'tqdm',
'momentchi2'
],
url='https://github.com/py-why/causal-learn',
packages=setuptools.find_packages(),
Expand Down
36 changes: 36 additions & 0 deletions tests/TestCIT_FastKCI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

import numpy as np

import causallearn.utils.cit as cit


class TestCIT_FastKCI(unittest.TestCase):
def test_Gaussian_dist(self):
np.random.seed(10)
X = np.random.randn(1200, 1)
X_prime = np.random.randn(1200, 1)
Y = X + 0.5 * np.random.randn(1200, 1)
Z = Y + 0.5 * np.random.randn(1200, 1)
data = np.hstack((X, X_prime, Y, Z))

pvalue01 = []
pvalue03 = []
pvalue032 = []
for K in [3, 10]:
for J in [8, 16]:
for use_gp in [True, False]:
cit_CIT = cit.CIT(data, 'fastkci', K=K, J=J, use_gp=use_gp)
pvalue01.append(round(cit_CIT(0, 1), 4))
pvalue03.append(round(cit_CIT(0, 3), 4))
pvalue032.append(round(cit_CIT(0, 3, {2}), 4))

pvalue01 = np.array(pvalue01)
pvalue03 = np.array(pvalue03)
pvalue032 = np.array(pvalue032)
self.assertTrue(np.all((0.0 <= pvalue01) & (pvalue01 <= 1.0)),
"pvalue01 contains invalid values")
self.assertTrue(np.all((0.0 <= pvalue03) & (pvalue03 <= 1.0)),
"pvalue03 contains invalid values")
self.assertTrue(np.all((0.0 <= pvalue032) & (pvalue032 <= 1.0)),
"pvalue032 contains invalid values")
38 changes: 38 additions & 0 deletions tests/TestCIT_RCIT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

import numpy as np

import causallearn.utils.cit as cit


class TestCIT_RCIT(unittest.TestCase):
def test_Gaussian_dist(self):
np.random.seed(10)
X = np.random.randn(300, 1)
X_prime = np.random.randn(300, 1)
Y = X + 0.5 * np.random.randn(300, 1)
Z = Y + 0.5 * np.random.randn(300, 1)
data = np.hstack((X, X_prime, Y, Z))

pvalue01 = []
pvalue03 = []
pvalue032 = []
for approx in ["lpd4", "hbe", "gamma", "chi2", "perm"]:
for num_f in [50, 100]:
for num_f2 in [5, 10]:
for rcit in [True, False]:
cit_CIT = cit.CIT(data, 'rcit', approx=approx, num_f=num_f,
num_f2=num_f2, rcit=rcit)
pvalue01.append(round(cit_CIT(0, 1), 4))
pvalue03.append(round(cit_CIT(0, 3), 4))
pvalue032.append(round(cit_CIT(0, 3, {2}), 4))

pvalue01 = np.array(pvalue01)
pvalue03 = np.array(pvalue03)
pvalue032 = np.array(pvalue032)
self.assertTrue(np.all((0.0 <= pvalue01) & (pvalue01 <= 1.0)),
"pvalue01 contains invalid values")
self.assertTrue(np.all((0.0 <= pvalue03) & (pvalue03 <= 1.0)),
"pvalue03 contains invalid values")
self.assertTrue(np.all((0.0 <= pvalue032) & (pvalue032 <= 1.0)),
"pvalue032 contains invalid values")
Loading