-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4121d74
commit c8cd0e6
Showing
12 changed files
with
2,206 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from .planar_graph import PlanarGraph | ||
from .planar_graph import LatticeGraph | ||
from .utils import matrix2edges | ||
from .utils import matrix2ijs | ||
from .utils import matrix2lil | ||
from .utils import matrix2inds | ||
from .utils import ijs2matrix | ||
from .utils import edges2matrix | ||
from .utils import make_symmetric_less | ||
from .utils import make_symmetric_more | ||
from .utils import is_symmetric | ||
|
||
|
||
from .vnn import vnn_graph | ||
from .vnn import estimate_d | ||
from .vnn import vnn_distance | ||
|
||
__all__ = ['PlanarGraph', | ||
'LatticeGraph', | ||
'matrix2ijs', | ||
'matrix2edges', | ||
'matrix2lil', | ||
'matrix2inds', | ||
'make_symmetric_less', | ||
'make_symmetric_more', | ||
'is_symmetric', | ||
'estimate_d', | ||
'ijs2matrix', | ||
'edges2matrix', | ||
'vnn_graph', | ||
'estimate_d', | ||
'vnn_distance', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import numpy as np | ||
from numpy.lib.stride_tricks import sliding_window_view | ||
from scipy.sparse import lil_matrix, csr_matrix, issparse, hstack, vstack | ||
|
||
from .utils import matrix2inds, edges2matrix | ||
|
||
|
||
def sort_wedges(wedges): | ||
# sort according first column | ||
idx = np.argsort(wedges[:, 0]) | ||
wedges = wedges[idx] | ||
# get split indices | ||
split_ind = np.unique(wedges[:, 0], return_index=True)[1][1:] | ||
# split wedges into wedge_groups | ||
wedge_groups = np.split(wedges, split_ind) | ||
# sort according second column with each group | ||
wedge_groups_ = [group[np.argsort(group[:, 1])] for group in wedge_groups] | ||
return np.vstack(wedge_groups_) | ||
|
||
|
||
def wiki_arctan2(y, x): | ||
angles = np.arctan2(y, x) | ||
angles = (angles + 2 * np.pi) % (2 * np.pi) | ||
return angles | ||
|
||
def get_wedges(pts, inds, sort=True): | ||
# number of pts == len(inds) | ||
# every row in inds stores the neighbors indices | ||
all_wedges = [] | ||
for i, js in enumerate(inds): | ||
if len(js) == 0: | ||
all_wedges.append([i, -1, -1]) | ||
elif len(js) == 1: | ||
all_wedges.append([js[0], -1, js[0]]) | ||
else: | ||
# thetas | ||
thetas = np.array([wiki_arctan2((pts[j] - pts[i])[1], (pts[j] - pts[i])[0]) for j in js]) | ||
# sort js according to thetas | ||
js = np.array(js)[np.argsort(thetas)] | ||
ijs = np.array([(i, j) for j in js]) | ||
# form wedges from ijs | ||
wedges = [(j2, i1, j1) for (i1, j1), (i1, j2) in zip(ijs, np.roll(ijs, 1, axis=0))] | ||
all_wedges.append(wedges) | ||
all_wedges = np.vstack(all_wedges) | ||
# sort all_wedges by first and second columns | ||
if sort: | ||
all_wedges = sort_wedges(all_wedges) | ||
return all_wedges | ||
|
||
|
||
# get polygons function | ||
def search_next_wedge(wedges, current, groups_start_end=None): | ||
# guaranteed to success | ||
if groups_start_end is None: | ||
# get start and end indices of each group | ||
split_ind = np.unique(wedges[:, 0], return_index=True)[1] | ||
split_ind = np.append(split_ind, len(wedges)) | ||
# len(groups_start_end) = len(pts) | ||
groups_start_end = sliding_window_view(split_ind, 2) | ||
# return next wedge which connected to current | ||
i, j, k = current | ||
# locate group j | ||
start, end = groups_start_end[j] | ||
group_j = wedges[start:end] | ||
# within group j, find first occurrence of k | ||
inds = np.where(group_j[:, 1] == k)[0] | ||
if inds.size == 0: | ||
# next wedge cannot be found | ||
next_wedge = None | ||
wedge_idx = None | ||
else: | ||
idx = np.where(group_j[:, 1] == k)[0][0] | ||
next_wedge = group_j[idx] | ||
wedge_idx = start + idx | ||
return next_wedge, wedge_idx | ||
|
||
|
||
def find_polygon(wedges, start_wedge_idx, used=None, groups_start_end=None): | ||
# return polygon indices AND update used | ||
|
||
if groups_start_end is None: | ||
# get start and end indices of each group | ||
split_ind = np.unique(wedges[:, 0], return_index=True)[1] | ||
split_ind = np.append(split_ind, len(wedges)) | ||
# len(groups_start_end) = len(pts) | ||
groups_start_end = sliding_window_view(split_ind, 2) | ||
|
||
if used is None: | ||
used = np.zeros(len(wedges)) | ||
|
||
start_wedge = wedges[start_wedge_idx] | ||
used[start_wedge_idx] = 1 | ||
|
||
if start_wedge[1] == -1: | ||
polygon = None | ||
else: | ||
polygon = [e for e in start_wedge] | ||
start_wedge_copy = start_wedge.copy() | ||
keep_search = True | ||
while keep_search: | ||
# should consider the case where next wedge cannot be found | ||
next_wedge, wedge_idx = search_next_wedge(wedges, start_wedge, groups_start_end) | ||
if next_wedge is None: | ||
# no polygon will be found | ||
keep_search = False | ||
# reset polygon to None | ||
polygon = None | ||
else: | ||
# update used | ||
used[wedge_idx] = 1 | ||
if np.all(next_wedge[-2:] == start_wedge_copy[0:2]): | ||
# no need to search anymore | ||
keep_search = False | ||
# return polygon indices | ||
polygon = np.array(polygon[:-1]) | ||
else: | ||
keep_search = True | ||
# update start_wedge and search | ||
start_wedge = next_wedge | ||
polygon.append(next_wedge[-1]) | ||
return polygon, used | ||
|
||
|
||
def locate_next_start(used): | ||
# return next start from unused | ||
starts = np.where(used == 0)[0] | ||
if starts.size == 0: | ||
return None | ||
else: | ||
return starts[0] | ||
|
||
|
||
def group_wedges(wedges, start=0): | ||
# step 1. sort wedges | ||
wedges = sort_wedges(wedges) | ||
# step 2. mark all wedges as unused, 0 means unused | ||
used = np.zeros(len(wedges)) | ||
|
||
# step 3 | ||
# get start and end indices of each group | ||
split_ind = np.unique(wedges[:, 0], return_index=True)[1] | ||
split_ind = np.append(split_ind, len(wedges)) | ||
# len(groups_start_end) = len(pts) | ||
groups_start_end = sliding_window_view(split_ind, 2) | ||
|
||
# find polygons | ||
polys = [] | ||
while start is not None: | ||
# i marks polygon index | ||
i = 0 | ||
polygon, used = find_polygon(wedges, start, used, groups_start_end) | ||
if polygon is not None: | ||
polys.append(polygon) | ||
# get a new start | ||
start = locate_next_start(used) | ||
return polys | ||
|
||
|
||
def grow_one_edge(pts, ijs): | ||
i = np.argmin(pts[:, 0]) | ||
p = (pts[i, 0] - 1, pts[i, 1]) | ||
pts_ = np.vstack([pts, p]) | ||
j = pts.shape[0] | ||
extra_ij = np.array([(i, j), (j, i)]) | ||
ijs_ = np.vstack([ijs, extra_ij]) | ||
return pts_, ijs_ | ||
|
||
def find_regions(pts, ijs, return_dict=False): | ||
# grow one edge to avoid a Large polygon from contour points | ||
pts_, ijs_ = grow_one_edge(pts, ijs) | ||
shape = (pts_.shape[0], pts_.shape[0]) | ||
matrix_ = edges2matrix(ijs_, shape) | ||
inds = matrix2inds(matrix_) | ||
# form wedges | ||
wedges = get_wedges(pts_, inds) | ||
# group wedges | ||
polys = group_wedges(wedges) | ||
polys = np.array([poly for poly in polys], dtype=object) | ||
if return_dict: | ||
ks = np.array([len(e) for e in polys]) | ||
polys = {str(e):np.vstack(polys[ks==e]) for e in np.unique(ks)} | ||
return polys | ||
else: | ||
return polys |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import pathlib | ||
import numpy as np | ||
import h5py | ||
import matplotlib.pyplot as plt | ||
from matplotlib.collections import LineCollection | ||
from scipy.sparse.csgraph import connected_components | ||
|
||
def is_array_like(a): | ||
return isinstance(a, (list, tuple, range, np.ndarray)) | ||
|
||
|
||
def extract_file_extension(file_name): | ||
ext = pathlib.Path(file_name).suffix[1:] | ||
return ext | ||
|
||
|
||
def remove_file_extension(file_name): | ||
file_name_ = pathlib.Path(file_name).stem | ||
return file_name_ | ||
|
||
def check_array_like(*args): | ||
status = np.array([is_array_like(arg) for arg in args]) | ||
return status | ||
|
||
class SaveGraphMixin: | ||
|
||
def save(self, file_name, path=None): | ||
file_name = remove_file_extension(file_name) | ||
if path is not None: | ||
file_name = path + file_name | ||
# keys and data | ||
names = np.array(['img', 'nodes', 'edges', 'regions',]) | ||
data_tuple = (self.img, self.nodes, self.edges, self.regions) | ||
mask = check_array_like(*data_tuple) | ||
# we need atleast one 'True' to proceed | ||
if not np.any(mask): | ||
raise ValueError('no valid inputs provided') | ||
|
||
names_ = names[mask] | ||
data_ = np.array(data_tuple, dtype=object)[mask] | ||
# create HDF5 file | ||
with h5py.File(file_name + '.hdf5', "w") as f: | ||
# create dataset | ||
for name, ds in zip(names_, data_): | ||
if name == 'regions': | ||
regions_g = f.create_group('regions') | ||
# get unique k | ||
ks = np.array([len(region) for region in self.regions]) | ||
for k in np.unique(ks): | ||
#g = regions_g.create_group(str(k)) | ||
data = np.vstack(ds[ks==k]) | ||
regions_g.create_dataset(str(k), data.shape, data.dtype, data) | ||
else: | ||
data = np.array(ds) | ||
dset = f.create_dataset(name, data.shape, data.dtype, data) | ||
|
||
|
||
class ShowGraphMixin: | ||
|
||
def show(self, ax=None, **kwargs): | ||
if ax is None: | ||
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2)) | ||
|
||
if 'color' not in kwargs: | ||
kwargs['color'] = '#2d3742ff' | ||
lines = np.array([(self.pts[i], self.pts[j]) for (i, j) in self.edges]) | ||
segs = LineCollection(lines, **kwargs) | ||
ax.add_collection(segs) | ||
ax.scatter(self.pts[:, 0], self.pts[:, 1], alpha=0) | ||
ax.axis('equal') | ||
|
||
def show_regions(self, ax=None): | ||
if ax is None: | ||
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2)) | ||
|
||
for region in self.regions: | ||
if len(region) <= 9: | ||
fc = 'C{}'.format(len(region)-3) | ||
poly = plt.Polygon(self.pts[region], fc=fc, ec='#2d3742', alpha=0.5) | ||
ax.add_patch(poly) | ||
ax.scatter(self.pts[:, 0], self.pts[:, 1], alpha=0) | ||
ax.axis('equal') |
Oops, something went wrong.