From 8eeb1c21fb7fc725bacf644adf8a14feb9d43822 Mon Sep 17 00:00:00 2001 From: Christopher Schmidt Date: Tue, 12 Jul 2022 22:56:48 +0200 Subject: [PATCH 1/4] spike for loading fixed graph structure from .gml file --- manm_cs/__main__.py | 39 +++++++++++++++++++++++----------- manm_cs/graph/graph_builder.py | 26 +++++++++++++++++++++-- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/manm_cs/__main__.py b/manm_cs/__main__.py index ffc78ab..001e5f2 100644 --- a/manm_cs/__main__.py +++ b/manm_cs/__main__.py @@ -119,6 +119,8 @@ def parse_args(): parser.add_argument('--output_samples_file', type=str, required=False, default=SAMPLES_FILE, help='Output file (path) for the generated samples csv. Relative to the directory from which the library is executed.' 'Specify without file extension.') + parser.add_argument('--graph_structure_file', type=str, required=False, + help='.gml file to load a fixed graph structure.') args = parser.parse_args() assert args.min_discrete_value_classes <= args.max_discrete_value_classes, \ @@ -132,18 +134,31 @@ def parse_args(): def graph_from_args(args) -> Graph: - return GraphBuilder() \ - .with_num_nodes(args.num_nodes) \ - .with_edge_density(args.edge_density) \ - .with_discrete_node_ratio(args.discrete_node_ratio) \ - .with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \ - .with_min_discrete_value_classes(args.min_discrete_value_classes) \ - .with_max_discrete_value_classes(args.max_discrete_value_classes) \ - .with_continuous_noise_std(args.continuous_noise_std) \ - .with_functions(args.functions) \ - .with_conditional_gaussian(args.conditional_gaussian) \ - .with_betas(args.beta_lower_limit, args.beta_upper_limit) \ - .build() + if args.graph_structure_file: + return GraphBuilder() \ + .with_graph_structure_file(args.graph_structure_file) \ + .with_discrete_node_ratio(args.discrete_node_ratio) \ + .with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \ + .with_min_discrete_value_classes(args.min_discrete_value_classes) \ + .with_max_discrete_value_classes(args.max_discrete_value_classes) \ + .with_continuous_noise_std(args.continuous_noise_std) \ + .with_functions(args.functions) \ + .with_conditional_gaussian(args.conditional_gaussian) \ + .with_betas(args.beta_lower_limit, args.beta_upper_limit) \ + .build() + else: + return GraphBuilder() \ + .with_num_nodes(args.num_nodes) \ + .with_edge_density(args.edge_density) \ + .with_discrete_node_ratio(args.discrete_node_ratio) \ + .with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \ + .with_min_discrete_value_classes(args.min_discrete_value_classes) \ + .with_max_discrete_value_classes(args.max_discrete_value_classes) \ + .with_continuous_noise_std(args.continuous_noise_std) \ + .with_functions(args.functions) \ + .with_conditional_gaussian(args.conditional_gaussian) \ + .with_betas(args.beta_lower_limit, args.beta_upper_limit) \ + .build() if __name__ == '__main__': diff --git a/manm_cs/graph/graph_builder.py b/manm_cs/graph/graph_builder.py index 9a36093..43d0cbc 100644 --- a/manm_cs/graph/graph_builder.py +++ b/manm_cs/graph/graph_builder.py @@ -3,7 +3,7 @@ import networkx as nx import numpy as np import random -from validation import validate_int, validate_float, validate_bool +from validation import validate_int, validate_float, validate_bool, validate_text from manm_cs.graph import Graph from manm_cs.noise import GaussianNoiseBuilder, DiscreteNoiseBuilder @@ -32,6 +32,13 @@ class GraphBuilder: functions: List[Tuple[float, Callable[...,float]]] + graph_structure_file_name: Optional[str] = None + + def with_graph_structure_file(self, file_name: str) -> 'GraphBuilder': + validate_text(file_name, pattern='.*\.gml') + self.graph_structure_file_name = file_name + return self + def with_num_nodes(self, num_nodes: int) -> 'GraphBuilder': validate_int(num_nodes, min_value=1) self.num_nodes = num_nodes @@ -164,7 +171,15 @@ def generate_continuous_variable(self, parents, node_idx) -> 'ContinuousVariable return ContinuousVariable(idx=node_idx, parents=parents, functions=functions, noise=noise, betas=betas) - def build(self, seed: int = 0) -> Graph: + def load_dag_from_file(self) -> 'DiGraph': + # Load graph using networkx package + dag = nx.read_gml(self.graph_structure_file_name) + assert nx.is_directed_acyclic_graph(dag) + self.with_num_nodes(nx.number_of_nodes(dag)) + self.with_edge_density(nx.density(dag)) + return dag + + def generate_dag_using_networkx(self, seed: int) -> 'DiGraph': # Generate graph using networkx package G = nx.gnp_random_graph(n=self.num_nodes, p=self.edge_density, seed=seed, directed=True) # Convert generated graph to DAG @@ -172,6 +187,13 @@ def build(self, seed: int = 0) -> Graph: dag.add_nodes_from(G) dag.add_edges_from([(u, v, {}) for (u, v) in G.edges() if u < v]) assert nx.is_directed_acyclic_graph(dag) + return dag + + def generate_or_load_dag(self, seed: int) -> 'DiGraph': + return self.load_dag_from_file() if self.graph_structure_file_name else self.generate_dag_using_networkx(seed) + + def build(self, seed: int = 0) -> Graph: + dag = self.generate_or_load_dag(seed) # Create list of topologically sorted nodes # Note, nodes are ordered already, sorting step may become relevant, if graph generation above is changed From 3903cc67cdb6686ee43dc07c28e74594682fe0bb Mon Sep 17 00:00:00 2001 From: Christopher Schmidt Date: Wed, 13 Jul 2022 11:16:57 +0200 Subject: [PATCH 2/4] make graph_structure_file mutual exclusive to num_nodes and edge_density for command_line options --- manm_cs/__main__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/manm_cs/__main__.py b/manm_cs/__main__.py index 001e5f2..90db02f 100644 --- a/manm_cs/__main__.py +++ b/manm_cs/__main__.py @@ -74,10 +74,15 @@ def parse_args(): parser = argparse.ArgumentParser( description='Generate a dataset for benchmarking causal structure learning using ' 'the mixed additive noise model') - parser.add_argument('--num_nodes', type=type_in_range(int, 1, None), required=True, + group1 = parser.add_mutually_exclusive_group(required=True) + group1.add_argument('--num_nodes', type=type_in_range(int, 1, None), help='Defines the number of nodes to be in the generated DAG.') - parser.add_argument('--edge_density', type=type_in_range(float, 0.0, 1.0), required=True, + arg_nx_file = group1.add_argument('--graph_structure_file', type=str, required=False, + help='valid .gml file to load a fixed graph structure to networkx.DiGraph structure.') + group2 = parser.add_mutually_exclusive_group(required=True) + group2.add_argument('--edge_density', type=type_in_range(float, 0.0, 1.0), help='Defines the density of edges in the generated DAG.') + group2._group_actions.append(arg_nx_file) parser.add_argument('--discrete_node_ratio', type=type_in_range(float, 0.0, 1.0), required=True, help='Defines the percentage of nodes that shall be of discrete type. Depending on its value ' 'the appropriate model (multivariate normal, mixed gaussian, discrete only) is chosen.') @@ -119,8 +124,6 @@ def parse_args(): parser.add_argument('--output_samples_file', type=str, required=False, default=SAMPLES_FILE, help='Output file (path) for the generated samples csv. Relative to the directory from which the library is executed.' 'Specify without file extension.') - parser.add_argument('--graph_structure_file', type=str, required=False, - help='.gml file to load a fixed graph structure.') args = parser.parse_args() assert args.min_discrete_value_classes <= args.max_discrete_value_classes, \ @@ -135,8 +138,10 @@ def parse_args(): def graph_from_args(args) -> Graph: if args.graph_structure_file: + dag = nx.read_gml(args.graph_structure_file) + assert nx.is_directed_acyclic_graph(dag) return GraphBuilder() \ - .with_graph_structure_file(args.graph_structure_file) \ + .with_networkx_DiGraph(dag) \ .with_discrete_node_ratio(args.discrete_node_ratio) \ .with_discrete_signal_to_noise_ratio(args.discrete_signal_to_noise_ratio) \ .with_min_discrete_value_classes(args.min_discrete_value_classes) \ From 7add298d34f80b4234fab84a4e365798f51a1f74 Mon Sep 17 00:00:00 2001 From: Christopher Schmidt Date: Wed, 13 Jul 2022 11:19:53 +0200 Subject: [PATCH 3/4] change internal graph builder function to take graph structure from networkx diGraph object instead of loading from .gml file via networkx, goal is to improve compatability with networkx format --- manm_cs/graph/graph_builder.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/manm_cs/graph/graph_builder.py b/manm_cs/graph/graph_builder.py index 43d0cbc..42a3f92 100644 --- a/manm_cs/graph/graph_builder.py +++ b/manm_cs/graph/graph_builder.py @@ -32,13 +32,20 @@ class GraphBuilder: functions: List[Tuple[float, Callable[...,float]]] - graph_structure_file_name: Optional[str] = None + dag: Optional[nx.DiGraph] = None def with_graph_structure_file(self, file_name: str) -> 'GraphBuilder': validate_text(file_name, pattern='.*\.gml') self.graph_structure_file_name = file_name return self + def with_networkx_DiGraph(self, graph: 'DiGraph') -> 'GraphBuilder': + assert nx.is_directed_acyclic_graph(graph) + self.with_num_nodes(nx.number_of_nodes(graph)) + self.with_edge_density(nx.density(graph)) + self.dag = graph + return self + def with_num_nodes(self, num_nodes: int) -> 'GraphBuilder': validate_int(num_nodes, min_value=1) self.num_nodes = num_nodes @@ -171,15 +178,7 @@ def generate_continuous_variable(self, parents, node_idx) -> 'ContinuousVariable return ContinuousVariable(idx=node_idx, parents=parents, functions=functions, noise=noise, betas=betas) - def load_dag_from_file(self) -> 'DiGraph': - # Load graph using networkx package - dag = nx.read_gml(self.graph_structure_file_name) - assert nx.is_directed_acyclic_graph(dag) - self.with_num_nodes(nx.number_of_nodes(dag)) - self.with_edge_density(nx.density(dag)) - return dag - - def generate_dag_using_networkx(self, seed: int) -> 'DiGraph': + def generate_dag(self, seed: int) -> 'DiGraph': # Generate graph using networkx package G = nx.gnp_random_graph(n=self.num_nodes, p=self.edge_density, seed=seed, directed=True) # Convert generated graph to DAG @@ -189,20 +188,17 @@ def generate_dag_using_networkx(self, seed: int) -> 'DiGraph': assert nx.is_directed_acyclic_graph(dag) return dag - def generate_or_load_dag(self, seed: int) -> 'DiGraph': - return self.load_dag_from_file() if self.graph_structure_file_name else self.generate_dag_using_networkx(seed) - def build(self, seed: int = 0) -> Graph: - dag = self.generate_or_load_dag(seed) + self.dag = self.generate_dag(seed) if self.dag is None else self.dag # Create list of topologically sorted nodes # Note, nodes are ordered already, sorting step may become relevant, if graph generation above is changed - top_sort_idx = list(nx.topological_sort(dag)) + top_sort_idx = list(nx.topological_sort(self.dag)) num_discrete_nodes = int(self.discrete_node_ratio * self.num_nodes) variables_by_idx: Dict[int, Variable] = {} for i, node_idx in enumerate(top_sort_idx): - parents = [variables_by_idx[idx] for idx in sorted(list(dag.predecessors(node_idx)))] + parents = [variables_by_idx[idx] for idx in sorted(list(self.dag.predecessors(node_idx)))] # Conditional Gaussian: if self.conditional_gaussian == True: From 333912775d013b7be1eb7f32810263b67f4cd336 Mon Sep 17 00:00:00 2001 From: Christopher Schmidt Date: Wed, 13 Jul 2022 11:27:48 +0200 Subject: [PATCH 4/4] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 94eb218..0bc4440 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ python3 -m twine upload dist/* | conditional_gaussian | 0 or 1 | 1 | '1' Defines that conditional gaussian model is assumed for a mixture of variables. Otherwise '0', discrete variables can have continuous parents. | | beta_lower_limit | (0, Inf) | 0.5 | Lower limit for beta values for influence of continuous parents. Betas are sampled uniform from the union of [-upper,-lower] and [lower,upper]. Upper limit see below. | | beta_upper_limit | (0, Inf) | 1 | Upper limit for beta values for influence of continuous parents. Betas are sampled uniform from the union of [-upper,-lower] and [lower,upper]. Lower limit see above. | +| graph_structure_file | | None | Defines a path to a .gml file for a fixed DAG structure (ignoring node and edge characteristics) used during manm_cs graph building. Note graph_structure_file is mutually exclusive to num_nodes and edge_density. | ## License