diff --git a/CHANGELOG.md b/CHANGELOG.md index d257074ee..8ee3bcbfe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,8 @@ ## Unreleased -- refer to latitude and longitude parameters as lat and lon consistently across package (#1068 #1069) +- fix references to latitude and longitude parameters as lat and lon consistently across package (#1068 #1069) +- fix handling dict and set attribute types when reloading GraphML files (#1075 #1077) ## 1.7.0 (2023-10-11) diff --git a/osmnx/io.py b/osmnx/io.py index 151865a09..5d4325228 100644 --- a/osmnx/io.py +++ b/osmnx/io.py @@ -421,6 +421,15 @@ def _convert_node_attr_types(G, dtypes=None): G : networkx.MultiDiGraph """ for _, data in G.nodes(data=True): + # first, eval stringified lists, dicts, or sets to convert them to objects + # lists, dicts, or sets would be custom attribute types added by a user + for attr, value in data.items(): + if (value.startswith("[") and value.endswith("]")) or ( + value.startswith("{") and value.endswith("}") + ): + with contextlib.suppress(SyntaxError, ValueError): + data[attr] = ast.literal_eval(value) + for attr in data.keys() & dtypes.keys(): data[attr] = dtypes[attr](data[attr]) return G @@ -446,10 +455,13 @@ def _convert_edge_attr_types(G, dtypes=None): # remove extraneous "id" attribute added by graphml saving data.pop("id", None) - # first, eval stringified lists to convert them to list objects + # first, eval stringified lists, dicts, or sets to convert them to objects # edge attributes might have a single value, or a list if simplified + # dicts or sets would be custom attribute types added by a user for attr, value in data.items(): - if value.startswith("[") and value.endswith("]"): + if (value.startswith("[") and value.endswith("]")) or ( + value.startswith("{") and value.endswith("}") + ): with contextlib.suppress(SyntaxError, ValueError): data[attr] = ast.literal_eval(value) diff --git a/tests/test_osmnx.py b/tests/test_osmnx.py index ebd63e8e3..e4c1d280c 100755 --- a/tests/test_osmnx.py +++ b/tests/test_osmnx.py @@ -441,6 +441,8 @@ def test_endpoints(): def test_graph_save_load(): """Test saving/loading graphs to/from disk.""" + fp = Path(ox.settings.data_folder) / "graph.graphml" + # save graph as shapefile and geopackage G = ox.graph_from_point(location_point, dist=500, network_type="drive") ox.save_graph_shapefile(G, directed=True) @@ -474,13 +476,28 @@ def test_graph_save_load(): edge_attrs = {n: bool(b) for n, b in zip(G.edges, bools)} nx.set_edge_attributes(G, edge_attrs, attr_name) + # create list, set, and dict attributes for nodes and edges + rand_ints_nodes = np.random.randint(0, 10, len(G.nodes)) + rand_ints_edges = np.random.randint(0, 10, len(G.edges)) + list_node_attrs = {n: [n, r] for n, r in zip(G.nodes, rand_ints_nodes)} + nx.set_node_attributes(G, list_node_attrs, "test_list") + list_edge_attrs = {e: [e, r] for e, r in zip(G.edges, rand_ints_edges)} + nx.set_edge_attributes(G, list_edge_attrs, "test_list") + set_node_attrs = {n: {n, r} for n, r in zip(G.nodes, rand_ints_nodes)} + nx.set_node_attributes(G, set_node_attrs, "test_set") + set_edge_attrs = {e: {e, r} for e, r in zip(G.edges, rand_ints_edges)} + nx.set_edge_attributes(G, set_edge_attrs, "test_set") + dict_node_attrs = {n: {n: r} for n, r in zip(G.nodes, rand_ints_nodes)} + nx.set_node_attributes(G, dict_node_attrs, "test_dict") + dict_edge_attrs = {e: {e: r} for e, r in zip(G.edges, rand_ints_edges)} + nx.set_edge_attributes(G, dict_edge_attrs, "test_dict") + # save/load graph as graphml file ox.save_graphml(G, gephi=True) ox.save_graphml(G, gephi=False) - ox.save_graphml(G, gephi=False, filepath=Path(ox.settings.data_folder) / "graph.graphml") - filepath = Path(ox.settings.data_folder) / "graph.graphml" + ox.save_graphml(G, gephi=False, filepath=fp) G2 = ox.load_graphml( - filepath, + fp, graph_dtypes={attr_name: ox.io._convert_bool_string}, node_dtypes={attr_name: ox.io._convert_bool_string}, edge_dtypes={attr_name: ox.io._convert_bool_string}, @@ -505,7 +522,7 @@ def test_graph_save_load(): # test custom data types nd = {"osmid": str} ed = {"length": str, "osmid": float} - G2 = ox.load_graphml(filepath, node_dtypes=nd, edge_dtypes=ed) + G2 = ox.load_graphml(fp, node_dtypes=nd, edge_dtypes=ed) # test loading graphml from a file stream file_bytes = Path.open(Path("tests/input_data/short.graphml"), "rb").read()