Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Surrogate Assisted Causal Testing #250

Merged
merged 60 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
e13ed5d
splitting linear regression into linear and polynomial estimators
rsomers1998 Nov 27, 2023
81d5b59
black
rsomers1998 Nov 27, 2023
5d56060
Outline of testing approach
rsomers1998 Nov 27, 2023
ea2bc52
cleanup + search algo interface
rsomers1998 Nov 27, 2023
88fb2fe
Testing approach outline with object interfaces
rsomers1998 Nov 28, 2023
35fe098
basic surrogate + fitness function generation
rsomers1998 Nov 28, 2023
d54f6e2
example code for simulator interfacing
rsomers1998 Nov 28, 2023
ab169a0
GA search base class
rsomers1998 Nov 28, 2023
e94dc6e
metadata pulled from DAG for inclusion in testing and their expected …
rsomers1998 Nov 28, 2023
b5a0f8c
Minimal working version for CauSAT + minor changes to identification …
rsomers1998 Nov 29, 2023
4a3c807
Working version, runtime GA constraints + GA configuration
rsomers1998 Nov 29, 2023
e93d59c
formatting
rsomers1998 Nov 29, 2023
0243472
moved examples and implemented multithreading
rsomers1998 Dec 1, 2023
5192284
Added contradiction data to output + black formatting
rsomers1998 Dec 4, 2023
7c5e84b
Additional return info + fixed multithreading
rsomers1998 Dec 6, 2023
727fb6d
updated aps case study for multithreading
rsomers1998 Dec 6, 2023
8b0987e
surrogate assisted code moved to seperate package + updating example …
rsomers1998 Dec 12, 2023
7500e12
Remove examples
rsomers1998 Dec 14, 2023
d5c70e5
Abstract base classes, formatting and other comments
rsomers1998 Dec 14, 2023
5a41b1b
Added fitness function comment + removed multiprocessing code
rsomers1998 Dec 15, 2023
6d04385
Linting for estimators.py
christopher-wild Dec 15, 2023
0b8c38f
Make remove_hidden_adjustment_sets method static
christopher-wild Dec 15, 2023
219b5d5
reformat & black causal_dag.py
christopher-wild Dec 15, 2023
bd0c44f
docstrings for causal_dag.py
christopher-wild Dec 15, 2023
e3f474e
linting + black for surrogate_search_algorithms.py
christopher-wild Dec 19, 2023
e3ad52d
Docstrings
christopher-wild Dec 21, 2023
e244470
Add: initial unit tests for causal surrogate
f-allian Dec 22, 2023
162e28d
Revert "Add: initial unit tests for causal surrogate"
f-allian Dec 22, 2023
df20a49
Included a condition for invalid data returned
rsomers1998 Jan 2, 2024
e66fc6e
Renamed PolynomialRegressionEstimator to CubicSplineRegressionEstimator
rsomers1998 Jan 3, 2024
7c2d01d
Merge remote-tracking branch 'origin/surrogateassisted' into surrogat…
christopher-wild Jan 3, 2024
051861e
function & class strings for ABC classes
christopher-wild Jan 3, 2024
7692084
Add all docstrings to causal_surrogate_assisted.py
christopher-wild Jan 3, 2024
1e22537
remaining linting + black for causal_surrogate_assisted.py
christopher-wild Jan 3, 2024
2d473e5
Move functionality into static method called create_gene_types
christopher-wild Jan 4, 2024
055b064
Add type parameter to configuration variable
christopher-wild Jan 4, 2024
94ce3c9
Ignore cell-var-from-loop for surrogate_search_algorithms.py
christopher-wild Jan 4, 2024
2289ae4
black
christopher-wild Jan 4, 2024
c8f9dd3
Remove generic exception
christopher-wild Jan 4, 2024
49b80a7
Formalise ABCs
christopher-wild Jan 4, 2024
20cad36
Pylinting + updating doc string and typing
rsomers1998 Jan 4, 2024
5df6b9d
Pygad signature comment
rsomers1998 Jan 4, 2024
eeda6e7
Updated doc string max executions
rsomers1998 Jan 4, 2024
27c737a
Basic tests to stop tests failing
rsomers1998 Jan 4, 2024
8c37082
Adding tests
rsomers1998 Jan 15, 2024
2e97af1
Updated dependencies
rsomers1998 Jan 15, 2024
ae8b86c
Updated dependencies
rsomers1998 Jan 15, 2024
c8ebb29
Test coverage
rsomers1998 Jan 15, 2024
0c852d3
Linting
rsomers1998 Jan 15, 2024
ee62aa0
Fixed test
rsomers1998 Jan 15, 2024
0a816a3
Fixed test
rsomers1998 Jan 15, 2024
cc0485f
Fixed bug in fitness function scope
rsomers1998 Jan 17, 2024
04b377d
linting
rsomers1998 Jan 17, 2024
cd3b41d
Merge remote-tracking branch 'origin/surrogateassisted' into surrogat…
f-allian Jan 17, 2024
743b130
Fix: linting and added __init__.py to surrogate
f-allian Jan 17, 2024
5e5055d
Fix: linting error
f-allian Jan 17, 2024
996d013
Fix: linting error #2
f-allian Jan 17, 2024
eaa4360
Fix: linting error #3
f-allian Jan 17, 2024
d432236
Fix: supress lintings locally
f-allian Jan 26, 2024
48fd185
Merge branch 'main' into surrogateassisted
christopher-wild Jan 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def list_all_min_sep(
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
if treatment_node_set_neighbours.difference(outcome_node_set):
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
node = set(sample(treatment_node_set_neighbours.difference(outcome_node_set), 1))
node = set(sample(sorted(treatment_node_set_neighbours.difference(outcome_node_set)), 1))

# 7.2. Add this node to the treatment node set and recurse (left branch)
yield from list_all_min_sep(
Expand Down Expand Up @@ -125,7 +125,6 @@ def close_separator(


class CausalDAG(nx.DiGraph):

"""A causal DAG is a directed acyclic graph in which nodes represent random variables and edges represent causality
between a pair of random variables. We implement a CausalDAG as a networkx DiGraph with an additional check that
ensures it is acyclic. A CausalDAG must be specified as a dot file.
Expand Down Expand Up @@ -500,11 +499,20 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
return True
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))

def identification(self, base_test_case: BaseTestCase):
@staticmethod
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
"""Remove variables labelled as hidden from adjustment set(s)
:param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
:param scenario: The modelling scenario which informs the variables that are hidden
"""
return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]

def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
"""Identify and return the minimum adjustment set

:param base_test_case: A base test case instance containing the outcome_variable and the
treatment_variable required for identification.
:param scenario: The modelling scenario relating to the tests
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
estimate as opposed to a purely associational estimate.
"""
Expand All @@ -520,6 +528,12 @@ def identification(self, base_test_case: BaseTestCase):
else:
raise ValueError("Causal effect should be 'total' or 'direct'")

if scenario is not None:
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)

if len(minimal_adjustment_sets) == 0:
return set()

minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
return minimal_adjustment_set

Expand Down
Empty file.
142 changes: 142 additions & 0 deletions causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Module containing classes to define and run causal surrogate assisted test cases"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable

from causal_testing.data_collection.data_collector import ObservationalDataCollector
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.estimators import CubicSplineRegressionEstimator


@dataclass
rsomers1998 marked this conversation as resolved.
Show resolved Hide resolved
class SimulationResult:
"""Data class holding the data and result metadata of a simulation"""

data: dict
fault: bool
relationship: str


class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
space to be searched"""

@abstractmethod
def search(
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
) -> list:
"""Function which implements a search routine which searches for the optimal fitness value for the specified
scenario
:param surrogate_models: The surrogate models to be searched
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""


class Simulator(ABC):
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the give
config file"""

@abstractmethod
def startup(self, **kwargs):
"""Function that when run, initialises and opens the Simulator"""

@abstractmethod
def shutdown(self, **kwargs):
"""Function to safely exit and shutdown the Simulator"""

@abstractmethod
def run_with_config(self, configuration: dict) -> SimulationResult:
"""Run the simulator with the given configuration and return the results in the structure of a
SimulationResult
:param configuration: The configuration required to initialise the Simulation
:return: Simulation results in the structure of the SimulationResult data class"""


class CausalSurrogateAssistedTestCase:
"""A class representing a single causal surrogate assisted test case."""

def __init__(
self,
specification: CausalSpecification,
search_algorithm: SearchAlgorithm,
simulator: Simulator,
):
self.specification = specification
self.search_algorithm = search_algorithm
self.simulator = simulator

def execute(
self,
data_collector: ObservationalDataCollector,
max_executions: int = 200,
rsomers1998 marked this conversation as resolved.
Show resolved Hide resolved
custom_data_aggregator: Callable[[dict, dict], dict] = None,
):
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input
space which is, therefore, most likely to indicate incorrect behaviour. This cadidate test case is run against
the simulator, checked for faults and the result returned with collected data
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
:param max_executions: Maximum number of simulator executions before exiting the search
:param custom_data_aggregator:
:return: tuple containing SimulationResult or str, execution number and collected data"""
data_collector.collect_data()

for i in range(max_executions):
surrogate_models = self.generate_surrogates(self.specification, data_collector)
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)

self.simulator.startup()
test_result = self.simulator.run_with_config(candidate_test_case)
self.simulator.shutdown()

if custom_data_aggregator is not None:
if data_collector.data is not None:
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
else:
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True)

if test_result.fault:
print(
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
f"expected {surrogate.expected_relationship}."
)
test_result.relationship = (
f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
)
return test_result, i + 1, data_collector.data

print("No fault found")
return "No fault found", i + 1, data_collector.data

def generate_surrogates(
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
) -> list[CubicSplineRegressionEstimator]:
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
:return: A list of surrogate models
"""
surrogate_models = []

for u, v in specification.causal_dag.graph.edges:
edge_metadata = specification.causal_dag.graph.adj[u][v]
if "included" in edge_metadata:
from_var = specification.scenario.variables.get(u)
to_var = specification.scenario.variables.get(v)
base_test_case = BaseTestCase(from_var, to_var)

minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario)

surrogate = CubicSplineRegressionEstimator(
u,
0,
0,
minimal_adjustment_set,
v,
4,
df=data_collector.data,
expected_relationship=edge_metadata["expected"],
)
surrogate_models.append(surrogate)

return surrogate_models
114 changes: 114 additions & 0 deletions causal_testing/surrogate/surrogate_search_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Module containing implementation of search algorithm for surrogate search """
# Fitness functions are required to be iteratively defined, including all variables within.

from operator import itemgetter
from pygad import GA

from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm


class GeneticSearchAlgorithm(SearchAlgorithm):
"""Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""

def __init__(self, delta=0.05, config: dict = None) -> None:
super().__init__()

self.delta = delta
self.config = config
self.contradiction_functions = {
christopher-wild marked this conversation as resolved.
Show resolved Hide resolved
"positive": lambda x: -1 * x,
"negative": lambda x: x,
"no_effect": abs,
"some_effect": lambda x: abs(1 / x),
}

# pylint: disable=too-many-locals
def search(
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
) -> list:
solutions = []

for surrogate in surrogate_models:
contradiction_function = self.contradiction_functions[surrogate.expected_relationship]

# The GA fitness function after including required variables into the function's scope
# Unused arguments are required for pygad's fitness function signature
#pylint: disable=cell-var-from-loop
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
surrogate.control_value = solution[0] - self.delta
surrogate.treatment_value = solution[0] + self.delta

adjustment_dict = {}
for i, adjustment in enumerate(surrogate.adjustment_set):
adjustment_dict[adjustment] = solution[i + 1]

ate = surrogate.estimate_ate_calculated(adjustment_dict)

return contradiction_function(ate)

gene_types, gene_space = self.create_gene_types(surrogate, specification)

ga = GA(
christopher-wild marked this conversation as resolved.
Show resolved Hide resolved
num_generations=200,
num_parents_mating=4,
fitness_func=fitness_function,
sol_per_pop=10,
num_genes=1 + len(surrogate.adjustment_set),
gene_space=gene_space,
gene_type=gene_types,
)

if self.config is not None:
for k, v in self.config.items():
if k == "gene_space":
raise ValueError(
"Gene space should not be set through config. This is generated from the causal "
"specification"
)
setattr(ga, k, v)

ga.run()
solution, fitness, _ = ga.best_solution()

solution_dict = {}
solution_dict[surrogate.treatment] = solution[0]
for idx, adj in enumerate(surrogate.adjustment_set):
solution_dict[adj] = solution[idx + 1]
solutions.append((solution_dict, fitness, surrogate))

return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges

@staticmethod
def create_gene_types(
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
) -> tuple[list, list]:
"""Generate the gene_types and gene_space for a given fitness function and specification
:param surrogate_model: Instance of a CubicSplineRegressionEstimator
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""

var_space = {}
var_space[surrogate_model.treatment] = {}
for adj in surrogate_model.adjustment_set:
var_space[adj] = {}

for relationship in list(specification.scenario.constraints):
rel_split = str(relationship).split(" ")

if rel_split[0] in var_space:
if rel_split[1] == ">=":
var_space[rel_split[0]]["low"] = int(rel_split[2])
elif rel_split[1] == "<=":
var_space[rel_split[0]]["high"] = int(rel_split[2])

gene_space = []
gene_space.append(var_space[surrogate_model.treatment])
for adj in surrogate_model.adjustment_set:
gene_space.append(var_space[adj])

gene_types = []
gene_types.append(specification.scenario.variables.get(surrogate_model.treatment).datatype)
for adj in surrogate_model.adjustment_set:
gene_types.append(specification.scenario.variables.get(adj).datatype)
return gene_types, gene_space
52 changes: 52 additions & 0 deletions causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,58 @@
return [ci_low, ci_high]


class CubicSplineRegressionEstimator(LinearRegressionEstimator):
"""A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
combination of parameters and basis functions of the variables.
"""

def __init__(
# pylint: disable=too-many-arguments
self,
treatment: str,
treatment_value: float,
control_value: float,
adjustment_set: set,
outcome: str,
basis: int,
df: pd.DataFrame = None,
effect_modifiers: dict[Variable:Any] = None,
formula: str = None,
alpha: float = 0.05,
expected_relationship=None,
):
super().__init__(
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
)

self.expected_relationship = expected_relationship

if effect_modifiers is None:
effect_modifiers = []

if formula is None:
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
model = self._run_linear_regression()

x = {"Intercept": 1, self.treatment: self.treatment_value}
if adjustment_config is not None:
for k, v in adjustment_config.items():
x[k] = v
if self.effect_modifiers is not None:
for k, v in self.effect_modifiers.items():
x[k] = v

Check warning on line 484 in causal_testing/testing/estimators.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/testing/estimators.py#L484

Added line #L484 was not covered by tests

treatment = model.predict(x).iloc[0]

x[self.treatment] = self.control_value
control = model.predict(x).iloc[0]

return treatment - control


class InstrumentalVariableEstimator(Estimator):
"""
Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"scipy~=1.7",
"statsmodels~=0.13",
"tabulate~=0.8",
"pydot~=1.4"
"pydot~=1.4",
"pygad~=3.2"
]
dynamic = ["version"]

Expand Down
Loading
Loading