Skip to content

Commit

Permalink
update io, impove reading dm files
Browse files Browse the repository at this point in the history
  • Loading branch information
jiadongdan committed Nov 5, 2024
1 parent 4121d74 commit c8cd0e6
Show file tree
Hide file tree
Showing 12 changed files with 2,206 additions and 25 deletions.
2 changes: 1 addition & 1 deletion mtflearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __dir__(self):
from mtflearn.features._zmoments import zmoments
from mtflearn.denoise._denoise_svd import denoise_svd
from mtflearn.denoise._denoise_svd import DenoiseSVD
from mtflearn.io._io_image import load_image
from mtflearn.io._io import load_image

__all__ = ['features',
'denoise',
Expand Down
33 changes: 33 additions & 0 deletions mtflearn/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .planar_graph import PlanarGraph
from .planar_graph import LatticeGraph
from .utils import matrix2edges
from .utils import matrix2ijs
from .utils import matrix2lil
from .utils import matrix2inds
from .utils import ijs2matrix
from .utils import edges2matrix
from .utils import make_symmetric_less
from .utils import make_symmetric_more
from .utils import is_symmetric


from .vnn import vnn_graph
from .vnn import estimate_d
from .vnn import vnn_distance

__all__ = ['PlanarGraph',
'LatticeGraph',
'matrix2ijs',
'matrix2edges',
'matrix2lil',
'matrix2inds',
'make_symmetric_less',
'make_symmetric_more',
'is_symmetric',
'estimate_d',
'ijs2matrix',
'edges2matrix',
'vnn_graph',
'estimate_d',
'vnn_distance',
]
184 changes: 184 additions & 0 deletions mtflearn/graph/find_regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.sparse import lil_matrix, csr_matrix, issparse, hstack, vstack

from .utils import matrix2inds, edges2matrix


def sort_wedges(wedges):
# sort according first column
idx = np.argsort(wedges[:, 0])
wedges = wedges[idx]
# get split indices
split_ind = np.unique(wedges[:, 0], return_index=True)[1][1:]
# split wedges into wedge_groups
wedge_groups = np.split(wedges, split_ind)
# sort according second column with each group
wedge_groups_ = [group[np.argsort(group[:, 1])] for group in wedge_groups]
return np.vstack(wedge_groups_)


def wiki_arctan2(y, x):
angles = np.arctan2(y, x)
angles = (angles + 2 * np.pi) % (2 * np.pi)
return angles

def get_wedges(pts, inds, sort=True):
# number of pts == len(inds)
# every row in inds stores the neighbors indices
all_wedges = []
for i, js in enumerate(inds):
if len(js) == 0:
all_wedges.append([i, -1, -1])
elif len(js) == 1:
all_wedges.append([js[0], -1, js[0]])
else:
# thetas
thetas = np.array([wiki_arctan2((pts[j] - pts[i])[1], (pts[j] - pts[i])[0]) for j in js])
# sort js according to thetas
js = np.array(js)[np.argsort(thetas)]
ijs = np.array([(i, j) for j in js])
# form wedges from ijs
wedges = [(j2, i1, j1) for (i1, j1), (i1, j2) in zip(ijs, np.roll(ijs, 1, axis=0))]
all_wedges.append(wedges)
all_wedges = np.vstack(all_wedges)
# sort all_wedges by first and second columns
if sort:
all_wedges = sort_wedges(all_wedges)
return all_wedges


# get polygons function
def search_next_wedge(wedges, current, groups_start_end=None):
# guaranteed to success
if groups_start_end is None:
# get start and end indices of each group
split_ind = np.unique(wedges[:, 0], return_index=True)[1]
split_ind = np.append(split_ind, len(wedges))
# len(groups_start_end) = len(pts)
groups_start_end = sliding_window_view(split_ind, 2)
# return next wedge which connected to current
i, j, k = current
# locate group j
start, end = groups_start_end[j]
group_j = wedges[start:end]
# within group j, find first occurrence of k
inds = np.where(group_j[:, 1] == k)[0]
if inds.size == 0:
# next wedge cannot be found
next_wedge = None
wedge_idx = None
else:
idx = np.where(group_j[:, 1] == k)[0][0]
next_wedge = group_j[idx]
wedge_idx = start + idx
return next_wedge, wedge_idx


def find_polygon(wedges, start_wedge_idx, used=None, groups_start_end=None):
# return polygon indices AND update used

if groups_start_end is None:
# get start and end indices of each group
split_ind = np.unique(wedges[:, 0], return_index=True)[1]
split_ind = np.append(split_ind, len(wedges))
# len(groups_start_end) = len(pts)
groups_start_end = sliding_window_view(split_ind, 2)

if used is None:
used = np.zeros(len(wedges))

start_wedge = wedges[start_wedge_idx]
used[start_wedge_idx] = 1

if start_wedge[1] == -1:
polygon = None
else:
polygon = [e for e in start_wedge]
start_wedge_copy = start_wedge.copy()
keep_search = True
while keep_search:
# should consider the case where next wedge cannot be found
next_wedge, wedge_idx = search_next_wedge(wedges, start_wedge, groups_start_end)
if next_wedge is None:
# no polygon will be found
keep_search = False
# reset polygon to None
polygon = None
else:
# update used
used[wedge_idx] = 1
if np.all(next_wedge[-2:] == start_wedge_copy[0:2]):
# no need to search anymore
keep_search = False
# return polygon indices
polygon = np.array(polygon[:-1])
else:
keep_search = True
# update start_wedge and search
start_wedge = next_wedge
polygon.append(next_wedge[-1])
return polygon, used


def locate_next_start(used):
# return next start from unused
starts = np.where(used == 0)[0]
if starts.size == 0:
return None
else:
return starts[0]


def group_wedges(wedges, start=0):
# step 1. sort wedges
wedges = sort_wedges(wedges)
# step 2. mark all wedges as unused, 0 means unused
used = np.zeros(len(wedges))

# step 3
# get start and end indices of each group
split_ind = np.unique(wedges[:, 0], return_index=True)[1]
split_ind = np.append(split_ind, len(wedges))
# len(groups_start_end) = len(pts)
groups_start_end = sliding_window_view(split_ind, 2)

# find polygons
polys = []
while start is not None:
# i marks polygon index
i = 0
polygon, used = find_polygon(wedges, start, used, groups_start_end)
if polygon is not None:
polys.append(polygon)
# get a new start
start = locate_next_start(used)
return polys


def grow_one_edge(pts, ijs):
i = np.argmin(pts[:, 0])
p = (pts[i, 0] - 1, pts[i, 1])
pts_ = np.vstack([pts, p])
j = pts.shape[0]
extra_ij = np.array([(i, j), (j, i)])
ijs_ = np.vstack([ijs, extra_ij])
return pts_, ijs_

def find_regions(pts, ijs, return_dict=False):
# grow one edge to avoid a Large polygon from contour points
pts_, ijs_ = grow_one_edge(pts, ijs)
shape = (pts_.shape[0], pts_.shape[0])
matrix_ = edges2matrix(ijs_, shape)
inds = matrix2inds(matrix_)
# form wedges
wedges = get_wedges(pts_, inds)
# group wedges
polys = group_wedges(wedges)
polys = np.array([poly for poly in polys], dtype=object)
if return_dict:
ks = np.array([len(e) for e in polys])
polys = {str(e):np.vstack(polys[ks==e]) for e in np.unique(ks)}
return polys
else:
return polys
82 changes: 82 additions & 0 deletions mtflearn/graph/mixin_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pathlib
import numpy as np
import h5py
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from scipy.sparse.csgraph import connected_components

def is_array_like(a):
return isinstance(a, (list, tuple, range, np.ndarray))


def extract_file_extension(file_name):
ext = pathlib.Path(file_name).suffix[1:]
return ext


def remove_file_extension(file_name):
file_name_ = pathlib.Path(file_name).stem
return file_name_

def check_array_like(*args):
status = np.array([is_array_like(arg) for arg in args])
return status

class SaveGraphMixin:

def save(self, file_name, path=None):
file_name = remove_file_extension(file_name)
if path is not None:
file_name = path + file_name
# keys and data
names = np.array(['img', 'nodes', 'edges', 'regions',])
data_tuple = (self.img, self.nodes, self.edges, self.regions)
mask = check_array_like(*data_tuple)
# we need atleast one 'True' to proceed
if not np.any(mask):
raise ValueError('no valid inputs provided')

names_ = names[mask]
data_ = np.array(data_tuple, dtype=object)[mask]
# create HDF5 file
with h5py.File(file_name + '.hdf5', "w") as f:
# create dataset
for name, ds in zip(names_, data_):
if name == 'regions':
regions_g = f.create_group('regions')
# get unique k
ks = np.array([len(region) for region in self.regions])
for k in np.unique(ks):
#g = regions_g.create_group(str(k))
data = np.vstack(ds[ks==k])
regions_g.create_dataset(str(k), data.shape, data.dtype, data)
else:
data = np.array(ds)
dset = f.create_dataset(name, data.shape, data.dtype, data)


class ShowGraphMixin:

def show(self, ax=None, **kwargs):
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))

if 'color' not in kwargs:
kwargs['color'] = '#2d3742ff'
lines = np.array([(self.pts[i], self.pts[j]) for (i, j) in self.edges])
segs = LineCollection(lines, **kwargs)
ax.add_collection(segs)
ax.scatter(self.pts[:, 0], self.pts[:, 1], alpha=0)
ax.axis('equal')

def show_regions(self, ax=None):
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2))

for region in self.regions:
if len(region) <= 9:
fc = 'C{}'.format(len(region)-3)
poly = plt.Polygon(self.pts[region], fc=fc, ec='#2d3742', alpha=0.5)
ax.add_patch(poly)
ax.scatter(self.pts[:, 0], self.pts[:, 1], alpha=0)
ax.axis('equal')
Loading

0 comments on commit c8cd0e6

Please sign in to comment.