Skip to content

Commit

Permalink
Merge pull request #64 from causy-dev/graph-features-path-analysis
Browse files Browse the repository at this point in the history
feat(graph): check for blocking colliders and noncolliders on paths
  • Loading branch information
this-is-sofia authored Nov 13, 2024
2 parents 2930320 + e43ba30 commit 15a2f1d
Showing 1 changed file with 83 additions and 25 deletions.
108 changes: 83 additions & 25 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,78 @@ def edge_of_type_exists(

return True

def descendants_of_node(
self, u: Union[Node, str], visited=None
) -> List[Union[Node, str]]:
"""
Returns a list of all descendants of the given node in a directed graph, including the input node itself.
:param u: The node or node ID to find descendants for.
:return: A list of descendants of a node including the input node itself.
"""

if visited is None:
visited = set()

if isinstance(u, Node):
u = u.id

# If this node has already been visited, return an empty list to avoid cycles
if u in visited:
return []

visited.add(u)
descendants = [self.node_by_id(u)]

if u not in self.edges:
return descendants

for child in self.edges[u]:
if self.directed_edge_exists(u, child) and not self.directed_edge_exists(
child, u
):
if child not in visited:
descendants.extend(self.descendants_of_node(child, visited))

return list(set(descendants))

def _is_a_collider_blocking(self, path, conditioning_set) -> bool:
"""
Check if a path is blocked by a collider which is not in the conditioning set and has no descendants in the conditioning set.
:return: Boolean indicating if the path is blocked
"""
is_path_blocked = False
for i in range(1, len(path) - 1):
if self.edge_of_type_exists(
path[i - 1], path[i], DirectedEdge()
) and self.edge_of_type_exists(path[i + 1], path[i], DirectedEdge()):
# if the node is a collider, check if the node or any of its descendants are in the conditioning set
is_path_blocked = True
for descendant in self.descendants_of_node(path[i]):
if descendant in conditioning_set:
is_path_blocked = False
return is_path_blocked

def _is_a_non_collider_in_conditioning_set(self, path, conditioning_set) -> bool:
"""
Check if a path is blocked by a non-collider which is in the conditioning set.
:param path:
:param conditioning_set:
:return:
"""
is_path_blocked = False
for i in range(1, len(path) - 1):
if path[i] in conditioning_set:
# make sure that node is a noncollider
if not (
self.edge_of_type_exists(path[i - 1].id, path[i].id, DirectedEdge())
and self.edge_of_type_exists(
path[i + 1].id, path[i].id, DirectedEdge()
)
):
# if the node is a non-collider and in the conditioning set, the path is blocked
is_path_blocked = True
return is_path_blocked

def are_nodes_d_separated(
self,
u: Union[Node, str],
Expand All @@ -326,37 +398,23 @@ def are_nodes_d_separated(
if u in conditioning_set or v in conditioning_set:
raise ValueError("Nodes u and v may not be in the conditioning set")

# check whether there is an open path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set
list_of_results_for_paths = []
# If there are no paths between u and v, they are d-separated
if list(self.all_paths_on_underlying_undirected_graph(u, v)) == []:
return True

list_of_results_for_paths = []
for path in self.all_paths_on_underlying_undirected_graph(u, v):

# If the path only has two nodes, it cannot be blocked and is open. Therefore, u and v are not d-separated
if len(path) == 2:
is_path_blocked = False
return False

for i in range(1, len(path) - 1):
is_path_blocked = False
is_path_blocked = False

# paths are d-separated if a collider is not in the conditioning set
if path[i] not in conditioning_set:
if self.edge_of_type_exists(
path[i - 1].id, path[i].id, DirectedEdge()
) and self.edge_of_type_exists(
path[i + 1].id, path[i].id, DirectedEdge()
):
is_path_blocked = True

# paths are d-separated if a non-collider is in the conditioning set
elif path[i] in conditioning_set:
if not (
self.edge_of_type_exists(
path[i - 1].id, path[i].id, DirectedEdge()
)
and self.edge_of_type_exists(
path[i + 1].id, path[i].id, DirectedEdge()
)
):
is_path_blocked = True
if self._is_a_collider_blocking(
path, conditioning_set
) or self._is_a_non_collider_in_conditioning_set(path, conditioning_set):
is_path_blocked = True

list_of_results_for_paths.append(is_path_blocked)

Expand Down

0 comments on commit 15a2f1d

Please sign in to comment.