Skip to content

Commit

Permalink
Merge pull request #46 from hpi-epic/graph_structure_from_file
Browse files Browse the repository at this point in the history
spike for loading fixed graph structure from .gml file
  • Loading branch information
JohannesHuegle authored Jul 13, 2022
2 parents 0422487 + 3339127 commit 986182d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 18 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ python3 -m twine upload dist/*
| conditional_gaussian | 0 or 1 | 1 | '1' Defines that conditional gaussian model is assumed for a mixture of variables. Otherwise '0', discrete variables can have continuous parents. |
| beta_lower_limit | (0, Inf) | 0.5 | Lower limit for beta values for influence of continuous parents. Betas are sampled uniform from the union of [-upper,-lower] and [lower,upper]. Upper limit see below. |
| beta_upper_limit | (0, Inf) | 1 | Upper limit for beta values for influence of continuous parents. Betas are sampled uniform from the union of [-upper,-lower] and [lower,upper]. Lower limit see above. |
| graph_structure_file | | None | Defines a path to a .gml file for a fixed DAG structure (ignoring node and edge characteristics) used during manm_cs graph building. Note graph_structure_file is mutually exclusive to num_nodes and edge_density. |

## License

Expand Down
48 changes: 34 additions & 14 deletions manm_cs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,15 @@ def parse_args():
parser = argparse.ArgumentParser(
description='Generate a dataset for benchmarking causal structure learning using '
'the mixed additive noise model')
parser.add_argument('--num_nodes', type=type_in_range(int, 1, None), required=True,
group1 = parser.add_mutually_exclusive_group(required=True)
group1.add_argument('--num_nodes', type=type_in_range(int, 1, None),
help='Defines the number of nodes to be in the generated DAG.')
parser.add_argument('--edge_density', type=type_in_range(float, 0.0, 1.0), required=True,
arg_nx_file = group1.add_argument('--graph_structure_file', type=str, required=False,
help='valid .gml file to load a fixed graph structure to networkx.DiGraph structure.')
group2 = parser.add_mutually_exclusive_group(required=True)
group2.add_argument('--edge_density', type=type_in_range(float, 0.0, 1.0),
help='Defines the density of edges in the generated DAG.')
group2._group_actions.append(arg_nx_file)
parser.add_argument('--discrete_node_ratio', type=type_in_range(float, 0.0, 1.0), required=True,
help='Defines the percentage of nodes that shall be of discrete type. Depending on its value '
'the appropriate model (multivariate normal, mixed gaussian, discrete only) is chosen.')
Expand Down Expand Up @@ -132,18 +137,33 @@ def parse_args():


def graph_from_args(args) -> Graph:
return GraphBuilder() \
.with_num_nodes(args.num_nodes) \
.with_edge_density(args.edge_density) \
.with_discrete_node_ratio(args.discrete_node_ratio) \
.with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \
.with_min_discrete_value_classes(args.min_discrete_value_classes) \
.with_max_discrete_value_classes(args.max_discrete_value_classes) \
.with_continuous_noise_std(args.continuous_noise_std) \
.with_functions(args.functions) \
.with_conditional_gaussian(args.conditional_gaussian) \
.with_betas(args.beta_lower_limit, args.beta_upper_limit) \
.build()
if args.graph_structure_file:
dag = nx.read_gml(args.graph_structure_file)
assert nx.is_directed_acyclic_graph(dag)
return GraphBuilder() \
.with_networkx_DiGraph(dag) \
.with_discrete_node_ratio(args.discrete_node_ratio) \
.with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \
.with_min_discrete_value_classes(args.min_discrete_value_classes) \
.with_max_discrete_value_classes(args.max_discrete_value_classes) \
.with_continuous_noise_std(args.continuous_noise_std) \
.with_functions(args.functions) \
.with_conditional_gaussian(args.conditional_gaussian) \
.with_betas(args.beta_lower_limit, args.beta_upper_limit) \
.build()
else:
return GraphBuilder() \
.with_num_nodes(args.num_nodes) \
.with_edge_density(args.edge_density) \
.with_discrete_node_ratio(args.discrete_node_ratio) \
.with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \
.with_min_discrete_value_classes(args.min_discrete_value_classes) \
.with_max_discrete_value_classes(args.max_discrete_value_classes) \
.with_continuous_noise_std(args.continuous_noise_std) \
.with_functions(args.functions) \
.with_conditional_gaussian(args.conditional_gaussian) \
.with_betas(args.beta_lower_limit, args.beta_upper_limit) \
.build()


if __name__ == '__main__':
Expand Down
26 changes: 22 additions & 4 deletions manm_cs/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import networkx as nx
import numpy as np
import random
from validation import validate_int, validate_float, validate_bool
from validation import validate_int, validate_float, validate_bool, validate_text

from manm_cs.graph import Graph
from manm_cs.noise import GaussianNoiseBuilder, DiscreteNoiseBuilder
Expand Down Expand Up @@ -32,6 +32,20 @@ class GraphBuilder:

functions: List[Tuple[float, Callable[...,float]]]

dag: Optional[nx.DiGraph] = None

def with_graph_structure_file(self, file_name: str) -> 'GraphBuilder':
validate_text(file_name, pattern='.*\.gml')
self.graph_structure_file_name = file_name
return self

def with_networkx_DiGraph(self, graph: 'DiGraph') -> 'GraphBuilder':
assert nx.is_directed_acyclic_graph(graph)
self.with_num_nodes(nx.number_of_nodes(graph))
self.with_edge_density(nx.density(graph))
self.dag = graph
return self

def with_num_nodes(self, num_nodes: int) -> 'GraphBuilder':
validate_int(num_nodes, min_value=1)
self.num_nodes = num_nodes
Expand Down Expand Up @@ -164,23 +178,27 @@ def generate_continuous_variable(self, parents, node_idx) -> 'ContinuousVariable
return ContinuousVariable(idx=node_idx, parents=parents, functions=functions,
noise=noise, betas=betas)

def build(self, seed: int = 0) -> Graph:
def generate_dag(self, seed: int) -> 'DiGraph':
# Generate graph using networkx package
G = nx.gnp_random_graph(n=self.num_nodes, p=self.edge_density, seed=seed, directed=True)
# Convert generated graph to DAG
dag = nx.DiGraph()
dag.add_nodes_from(G)
dag.add_edges_from([(u, v, {}) for (u, v) in G.edges() if u < v])
assert nx.is_directed_acyclic_graph(dag)
return dag

def build(self, seed: int = 0) -> Graph:
self.dag = self.generate_dag(seed) if self.dag is None else self.dag

# Create list of topologically sorted nodes
# Note, nodes are ordered already, sorting step may become relevant, if graph generation above is changed
top_sort_idx = list(nx.topological_sort(dag))
top_sort_idx = list(nx.topological_sort(self.dag))
num_discrete_nodes = int(self.discrete_node_ratio * self.num_nodes)

variables_by_idx: Dict[int, Variable] = {}
for i, node_idx in enumerate(top_sort_idx):
parents = [variables_by_idx[idx] for idx in sorted(list(dag.predecessors(node_idx)))]
parents = [variables_by_idx[idx] for idx in sorted(list(self.dag.predecessors(node_idx)))]

# Conditional Gaussian:
if self.conditional_gaussian == True:
Expand Down

0 comments on commit 986182d

Please sign in to comment.