From 8e9fe021a8b88bdfd7f877fb33dd7b9000a16e20 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 15:17:40 -0700 Subject: [PATCH 01/37] adding code for sift generation of point matches for stitching --- .../stitching_modules/acstitch/io.py | 27 ++ .../stitching_modules/acstitch/sift_stitch.py | 238 ++++++++++++++++++ .../stitching_modules/acstitch/stitch.py | 19 ++ .../stitching_modules/acstitch/utils.py | 63 +++++ .../stitching_modules/acstitch/zarrutils.py | 16 ++ 5 files changed, 363 insertions(+) create mode 100644 acpreprocessing/stitching_modules/acstitch/io.py create mode 100644 acpreprocessing/stitching_modules/acstitch/sift_stitch.py create mode 100644 acpreprocessing/stitching_modules/acstitch/stitch.py create mode 100644 acpreprocessing/stitching_modules/acstitch/utils.py create mode 100644 acpreprocessing/stitching_modules/acstitch/zarrutils.py diff --git a/acpreprocessing/stitching_modules/acstitch/io.py b/acpreprocessing/stitching_modules/acstitch/io.py new file mode 100644 index 0000000..dd1762f --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/io.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Jul 31 13:33:46 2023 + +@author: kevint +""" + +import json +import gzip +import numpy + +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, numpy.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + +def save_pointmatch_file(pmdata,jsonpath): + with gzip.open(jsonpath, 'w') as fout: + fout.write(json.dumps(pmdata).encode('utf-8')) + + +def read_pointmatch_file(jsonpath): + with gzip.open(jsonpath, 'r') as fin: + data = json.loads(fin.read().decode('utf-8')) + return data \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py new file mode 100644 index 0000000..a7026e4 --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -0,0 +1,238 @@ +# code based on Russel's em-stitch notebook for lens correction +# defines a sift_detector class and utility functions + +import numpy as np +import cv2 as cv +from matplotlib import pyplot as plt + +class SiftDetector(object): + def __init__(self,clahe_kwargs,sift_kwargs,flann_args,ratio=0.7,min_inliers=100): + # CLAHE equalizer + self.clahe = cv.createCLAHE(**clahe_kwargs) + # Initiate SIFT detector + self.sift = cv.SIFT_create(**sift_kwargs) + # FLANN feature matcher + self.flann = cv.FlannBasedMatcher(*flann_args) + self.ratio = ratio + self.minin = min_inliers + + def detect_keypoints(self,img): + cimg1 = self.clahe.apply(img) + cimg1 = np.sqrt(cimg1).astype('uint8') + kp1, des1 = self.sift.detectAndCompute(cimg1,None) + return kp1,des1,cimg1 + + def detect_matches(self,kp1,des1,cimg1,img,draw=False): + cimg2 = self.clahe.apply(img) + cimg2 = np.sqrt(cimg2).astype('uint8') + # find the keypoints and descriptors with SIFT + kp2, des2 = self.sift.detectAndCompute(cimg2,None) + matches = self.flann.knnMatch(des1,des2,k=2) + # Need to draw only good matches, so create a mask + matchesMask = [[0,0] for i in range(len(matches))] + # ratio test as per Lowe's paper + good = [] + for i,(m,n) in enumerate(matches): + if m.distance < self.ratio*n.distance: + matchesMask[i]=[1,0] + good.append(m) + print(len(good)) + k1xy = np.array([np.array(k.pt) for k in kp1]) + k2xy = np.array([np.array(k.pt) for k in kp2]) + k1 = [] + k2 = [] + if len(good)>self.minin: + src_pts = np.float32([ kp1[m.queryIdx].pt for m in good ]).reshape(-1,1,2) + dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1,1,2) + M, mask = cv.findHomography(src_pts, dst_pts, cv.RANSAC,5.0) + matchesMask = mask.ravel().tolist() + if draw: + draw_params = dict(matchColor = (0,255,0), # draw matches in green color + singlePointColor = None, + matchesMask = matchesMask, # draw only inliers + flags = 2) + img4 = cv.drawMatches(cimg1,kp1,cimg2,kp2,good,None,**draw_params) + plt.figure(figsize=(20,20)) + plt.imshow(img4, 'gray'),plt.show() + + good = np.array(good)[np.array(matchesMask).astype('bool')] + imgIdx = np.array([g.imgIdx for g in good]) + tIdx = np.array([g.trainIdx for g in good]) + qIdx = np.array([g.queryIdx for g in good]) + for i in range(len(tIdx)): + if imgIdx[i] == 1: + k1.append(k1xy[tIdx[i]]) + k2.append(k2xy[qIdx[i]]) + else: + k1.append(k1xy[qIdx[i]]) + k2.append(k2xy[tIdx[i]]) + + if len(k1)>0: + k1 = np.array(k1) + k2 = np.array(k2) + # limit number of matches to random subset if too many + if k1.shape[0] > 10000: + a = np.arange(k1.shape[0]) + np.random.shuffle(a) + k1 = k1[a[0: 10000], :] + k2 = k2[a[0: 10000], :] + + return k1,k2 + + def detect_and_combine(self,kp1,des1,cimg1,imgStack,draw=False,axis=2,max_only=False): + k1_tot = [] + k2_tot = [] + good = np.zeros(imgStack.shape[axis],dtype=int) + islice = [] + for i in range(imgStack.shape[axis]): + if axis==2: + img = imgStack[:,:,i] + elif axis==1: + img = imgStack[:,i,:] + k1,k2 = self.detect_matches(kp1,des1,cimg1,img,draw) + if isinstance(k1, np.ndarray): + k1_tot.append(k1) + k2_tot.append(k2) + good[i] = k1.shape[0] + islice.append(np.ones(k1.shape,dtype=int)*i) + if k1_tot: + if max_only: + imax = np.argmax(good[good>0]) + return k1_tot[imax],k2_tot[imax],good,islice[imax] + else: + return np.concatenate(k1_tot),np.concatenate(k2_tot),good,np.concatenate(islice) + else: + return None,None,None,None + + + def stitch_over_segments(self,slice_axes,p_dslist,q_dslist,zstarts,zlength,**kwargs): + pmlist = [] + if slice_axes == "zx": + for zs in zstarts: + seglist = self.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + pmlist.append(seglist) + if pmlist: + n_tiles = len(pmlist[0][0]) + tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] + for pm in pmlist: + for i in range(2): + for ii in range(n_tiles): + tile_pmlist[i][ii].append(pm[i][ii]) + p_ptlist = [np.concatenate(pm) for pm in tile_pmlist[0]] + q_ptlist = [np.concatenate(pm) for pm in tile_pmlist[1]] + return p_ptlist,q_ptlist + else: + return None,None + + + def run_zx_stitch(self, + p_srclist, + q_srclist, + z0,z1,i_slice,j_slice,ny,dy, + scatter=False): + # estimate translation between strips with the median 2D displacement of matching point correspondences + siftklist = [] + j_slices = np.ones(len(p_srclist),dtype=int)*j_slice + for i,dsRef in enumerate(p_srclist): + ji = j_slices[i] + # iterate over each strip and its subsequent neighbor to look for correspondences and estimate median + dsStack = q_srclist[i] + imgRef = dsRef[0,0,z0:z1,i_slice,:] + # detect SIFT keypoints for reference slice + kp1, des1, cimg1 = self.detect_keypoints(imgRef) + imgStack = dsStack[0,0,z0:z1,(ji-ny*dy):(ji+(ny+1)*dy):dy,:] + # detect correspondences in slices from neighboring strip + k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=1,max_only=True) + print("Number of correspondences: " + str(good)) + if not k1_tot is None and k1_tot.shape[0]>200: + k = k2_tot-k1_tot + print('total correspondences for analysis: ' + str(k.shape[0])) + # estimate stitching translation with median displacement + km = np.median(k,axis=0) + print('median pixel displacements:' + str(km)) + # display scatter of displacements around median estimate + if scatter: + plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) + plt.xlim((-5,5)) + plt.ylim((-5,5)) + plt.show() + # identify slice index with most correspondences + j_slices[i] = ji - ny*dy + dy*np.argmax(good) + siftklist.append((k1_tot,k2_tot,k2slice)) + else: + print("not enough correspondences for strip " + str(i)) + siftklist.append(None) + siftstitch = [[],[]] + for i,s in enumerate(siftklist): + zi = z0 + yi = j_slices[i] + for ii,t in enumerate(siftstitch): + if not s is None: + kzyx = np.empty((s[ii].shape[0],3)) + kzyx[:,0] = s[ii][:,0] + zi + kzyx[:,1] = i_slice if ii == 0 else yi + kzyx[:,2] = s[ii][:,1] + t.append(kzyx) + else: + t.append(None) + return siftstitch + + def run_zy_stitch(sift_detector, + p_srclist, + q_srclist, + z0,z1,i_slice,j_slice,nx,dx, + scatter=False): + # estimate translation between strips with the median 2D displacement of matching point correspondences + kms = np.zeros((len(p_srclist),2)) + siftklist = [] + j_slices = np.ones(len(p_srclist),dtype=int)*j_slice + for i,dsRef in enumerate(p_srclist): + ji = j_slices[i] + # iterate over each strip and its subsequent neighbor to look for correspondences and estimate median + dsStack = q_srclist[i] + imgRef = dsRef[0,0,z0:z1,:,i_slice] + # detect SIFT keypoints for reference slice + kp1, des1, cimg1 = sift_detector.detect_keypoints(imgRef) + imgStack = dsStack[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] + # detect correspondences in slices from neighboring strip + k1_tot,k2_tot,good,k2slice = sift_detector.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=False) + print("Number of correspondences: " + str(good)) + if not k1_tot is None and k1_tot.shape[0]>50: + k = k2_tot-k1_tot + print('total correspondences for analysis: ' + str(k.shape[0])) + # estimate stitching translation with median displacement + km = np.median(k,axis=0) + print('median pixel displacements:' + str(km)) + kms[i] = km + # display scatter of displacements around median estimate + if scatter: + plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) + plt.xlim((-5,5)) + plt.ylim((-5,5)) + plt.show() + # identify slice index with most correspondences + j_slices[i] = ji - nx*dx + dx*np.argmax(good) + siftklist.append((k1_tot,k2_tot,k2slice)) + else: + print("not enough correspondences for strip " + str(i)) + siftklist.append(None) + + trest = np.zeros((kms.shape[0],3)) + trest[:,0] = kms[:,1] + trest[:,1] = kms[:,0] + trest[:,2] = j_slices - j_slice + + siftstitch = [[],[]] + for i,s in enumerate(siftklist): + zi = z0 + yi = j_slices[i] + for ii,t in enumerate(siftstitch): + if not s is None: + kzyx = np.empty((s[ii].shape[0],3)) + kzyx[:,0] = s[ii][:,0] + zi + kzyx[:,2] = i_slice if ii == 0 else yi + kzyx[:,1] = s[ii][:,1] + t.append(kzyx) + else: + t.append(None) + return siftstitch,trest \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py new file mode 100644 index 0000000..bf14b24 --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Jul 31 13:39:58 2023 + +@author: kevint +""" + +from sift_stitch import SiftDetector +from zarrutils import get_group_from_src + +def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): + p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] + q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] + sd = SiftDetector(**sift_kwargs) + p_ptlist,q_ptlist = sd.stitch_over_segments("zx",p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) + pmlist = [] + for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): + pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) + return pmlist \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/utils.py b/acpreprocessing/stitching_modules/acstitch/utils.py new file mode 100644 index 0000000..8489c3a --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/utils.py @@ -0,0 +1,63 @@ +import numpy as np +import os +import zarr +import cv2 as cv +import json +import gzip + +def get_lists(basePath,baseName,acqName,segNum,centerTile,xIndList,yIndList,xShift=[],yShift=[]): + fpaths = [] + shifts = [] + tiles = [] + for i,x in enumerate(xIndList): + for j,y in enumerate(yIndList): + if not (x==centerTile[0] and y==centerTile[1]): + tstr = str(x) + '_' + str(y) + '_' + str(segNum) + runDir = baseName + '_' + acqName + '_' + str(x) + stripDir = acqName + '_' + str(x) + '_Pos' + str(y) + tifName = acqName + '_' + tstr + '.tif' + fpaths.append(os.path.join(basePath,runDir,stripDir,tifName)) + tiles.append(tstr) + if not xShift: + m = int(centerTile[0]-x) + else: + m = xShift[i] + if not yShift: + n = int(centerTile[1]-y) + else: + n = yShift[j] + shifts.append([m,n]) + return fpaths,tiles,shifts + +def get_dataset(zpath,grpname,miplvl=0): + zf = zarr.open(zpath) + return zf[grpname][miplvl] + +def get_cc_windows(imgRef,img=None,pixShifts=[],support=[0.1,0.1]): + clahe = cv.createCLAHE(clipLimit=5.0,tileGridSize=(8,8)) + if img is None: + img = imgRef + dshape = [int(np.around(d*support[i])) for i,d in enumerate(imgRef.shape)] + start = [int(np.floor((d-dshape[i])/2)) for i,d in enumerate(imgRef.shape)] + ccRef = imgRef[start[0]:start[0]+dshape[0],start[1]:start[1]+dshape[1]] + ccRef = clahe.apply(ccRef) + if not len(pixShifts)==2: + pixShifts = [0,0] + ss = [0,0] + indShifts = pixShifts[::-1] + for i in range(2): + if start[i]+indShifts[i]-dshape[i]<0: + ss[i] = 0 + elif start[i]+indShifts[i]+2*dshape[i]>= img.shape[i]: + ss[i] = img.shape[i]-3*dshape[i] + else: + ss[i] = start[i] + indShifts[i] - dshape[i] + ccTarget = img[ss[0]:ss[0]+3*dshape[0],ss[1]:ss[1]+3*dshape[1]] + ccTarget = clahe.apply(ccTarget) + return ccRef,ccTarget + +def get_pixel_shift_xy(ij,resultShape): + cij = (np.array(resultShape)-1)/2 + xy_shift = (np.array(ij) - cij)[::-1].astype(int) + return xy_shift + diff --git a/acpreprocessing/stitching_modules/acstitch/zarrutils.py b/acpreprocessing/stitching_modules/acstitch/zarrutils.py new file mode 100644 index 0000000..a1e6288 --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/zarrutils.py @@ -0,0 +1,16 @@ +import os +import zarr + +def get_zarr_group(zpath,grpname): + # key to working with zarr files + # group contains mip datasets and dataset attributes + zf = zarr.open(zpath) + return zf[grpname] + +def get_group_from_src(srcpath): + # returns zarr group given a neuroglancer source path + # used to get datasets from neuroglancer layer json + pathout = 'zarr://http://bigkahuna.corp.alleninstitute.org/ACdata' # Url for ACdata for NG hosted on BigKahuna + pathin = 'J:' # Local path to ACdata + s = os.path.split(srcpath.replace(pathout,pathin)) + return get_zarr_group(zpath=s[0],grpname=s[1]) \ No newline at end of file From 67d27d6e971e11616f66aed927f9b77d6e37ebd8 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 15:23:28 -0700 Subject: [PATCH 02/37] fixing import paths --- acpreprocessing/stitching_modules/acstitch/stitch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index bf14b24..fa5a381 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -5,8 +5,8 @@ @author: kevint """ -from sift_stitch import SiftDetector -from zarrutils import get_group_from_src +from acpreprocessing.stitching_modules.acstitch.sift_stitch import SiftDetector +from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] From 3e34d34057bcf9a95e67dda2810cf30482423945 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 20:01:46 -0700 Subject: [PATCH 03/37] json encoding of ndarrays --- acpreprocessing/stitching_modules/acstitch/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/io.py b/acpreprocessing/stitching_modules/acstitch/io.py index dd1762f..c90aab2 100644 --- a/acpreprocessing/stitching_modules/acstitch/io.py +++ b/acpreprocessing/stitching_modules/acstitch/io.py @@ -18,7 +18,7 @@ def default(self, obj): def save_pointmatch_file(pmdata,jsonpath): with gzip.open(jsonpath, 'w') as fout: - fout.write(json.dumps(pmdata).encode('utf-8')) + fout.write(json.dumps(pmdata,cls=NumpyEncoder).encode('utf-8')) def read_pointmatch_file(jsonpath): From f6fd08e49566e1511a854368b3388292bde2092c Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 20:09:23 -0700 Subject: [PATCH 04/37] fixing pointmatch file io --- acpreprocessing/stitching_modules/acstitch/io.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/acpreprocessing/stitching_modules/acstitch/io.py b/acpreprocessing/stitching_modules/acstitch/io.py index c90aab2..365fa78 100644 --- a/acpreprocessing/stitching_modules/acstitch/io.py +++ b/acpreprocessing/stitching_modules/acstitch/io.py @@ -24,4 +24,8 @@ def save_pointmatch_file(pmdata,jsonpath): def read_pointmatch_file(jsonpath): with gzip.open(jsonpath, 'r') as fin: data = json.loads(fin.read().decode('utf-8')) + if data: + for tspec in data: + for key in ["p_pts","q_pts"]: + tspec[key] = numpy.asarray(tspec[key]) return data \ No newline at end of file From 16ab05889105c236e88d024d0c75b36dec181472 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 20:52:03 -0700 Subject: [PATCH 05/37] fix handling of None output --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 3 ++- acpreprocessing/stitching_modules/acstitch/stitch.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index a7026e4..d13e8fe 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -117,7 +117,8 @@ def stitch_over_segments(self,slice_axes,p_dslist,q_dslist,zstarts,zlength,**kwa for pm in pmlist: for i in range(2): for ii in range(n_tiles): - tile_pmlist[i][ii].append(pm[i][ii]) + if not pm[i][ii] is None: + tile_pmlist[i][ii].append(pm[i][ii]) p_ptlist = [np.concatenate(pm) for pm in tile_pmlist[0]] q_ptlist = [np.concatenate(pm) for pm in tile_pmlist[1]] return p_ptlist,q_ptlist diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index fa5a381..6cff994 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -14,6 +14,7 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti sd = SiftDetector(**sift_kwargs) p_ptlist,q_ptlist = sd.stitch_over_segments("zx",p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) pmlist = [] - for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): - pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) + if not p_ptlist is None: + for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): + pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) return pmlist \ No newline at end of file From a6a27b521ff7a953a4aaf8f961ad8d3e9ce41545 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 20:52:55 -0700 Subject: [PATCH 06/37] remove print --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index d13e8fe..2c84b54 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -36,7 +36,7 @@ def detect_matches(self,kp1,des1,cimg1,img,draw=False): if m.distance < self.ratio*n.distance: matchesMask[i]=[1,0] good.append(m) - print(len(good)) + #print(len(good)) k1xy = np.array([np.array(k.pt) for k in kp1]) k2xy = np.array([np.array(k.pt) for k in kp2]) k1 = [] From cf45746a59b4319851929eb88285ebb56d524145 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 31 Jul 2023 23:05:34 -0700 Subject: [PATCH 07/37] handle empty lists --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 4 ++-- acpreprocessing/stitching_modules/acstitch/stitch.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 2c84b54..b5b6949 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -119,8 +119,8 @@ def stitch_over_segments(self,slice_axes,p_dslist,q_dslist,zstarts,zlength,**kwa for ii in range(n_tiles): if not pm[i][ii] is None: tile_pmlist[i][ii].append(pm[i][ii]) - p_ptlist = [np.concatenate(pm) for pm in tile_pmlist[0]] - q_ptlist = [np.concatenate(pm) for pm in tile_pmlist[1]] + p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] + q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] return p_ptlist,q_ptlist else: return None,None diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 6cff994..ab11906 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -16,5 +16,6 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): - pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) + if not p_pts is None and len(p_pts) > 0: + pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) return pmlist \ No newline at end of file From 56938fe97c7e074cc1b0279fd31c7ffe5580051d Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 12:41:48 -0700 Subject: [PATCH 08/37] adding ccorr support --- .../acstitch/ccorr_stitch.py | 21 +++++ .../stitching_modules/acstitch/rtccorr.py | 80 +++++++++++++++++++ .../stitching_modules/acstitch/stitch.py | 59 +++++++++++++- 3 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py create mode 100644 acpreprocessing/stitching_modules/acstitch/rtccorr.py diff --git a/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py new file mode 100644 index 0000000..4a3c258 --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py @@ -0,0 +1,21 @@ +import numpy +from acpreprocessing.stitching_modules.acstitch.rtccorr import get_point_correspondence + + +def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False): + w = numpy.asarray(w,dtype=int) + if len(A1_pts.shape)<2: + A1_pts = numpy.array([A1_pts]) + A2_pts = numpy.array([A2_pts]) + pm1 = [] + pm2 = [] + for p,q in zip(A1_pts.astype(int),A2_pts.astype(int)): + A2sub = A2_ds[0,0,(q-w)[0]:(q+w)[0],(q-w)[1]:(q+w)[1],(q-w)[2]:(q+w)[2]] + A1sub = A1_ds[0,0,(p-w)[0]:(p+w)[0],(p-w)[1]:(p+w)[1],(p-w)[2]:(p+w)[2]] + p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=0.8,padarray=pad,value_threshold=1000) + if not p1 is None: + pm1.append(p1) + pm2.append(p2) + if pm1: + return numpy.asarray(pm1),numpy.asarray(pm2) + return None,None \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/rtccorr.py b/acpreprocessing/stitching_modules/acstitch/rtccorr.py new file mode 100644 index 0000000..502bb7b --- /dev/null +++ b/acpreprocessing/stitching_modules/acstitch/rtccorr.py @@ -0,0 +1,80 @@ + +import numpy +import scipy.ndimage + +def correlate_fftns(fft1, fft2): + prod = fft1 * fft2.conj() + res = numpy.fft.ifftn(prod) + + corr = numpy.fft.fftshift(res).real + return corr + + +def ccorr_fftn(img1, img2): + # TODO do we want to pad this? + fft1 = numpy.fft.fftn(img1) + fft2 = numpy.fft.fftn(img2) + + return correlate_fftns(fft1, fft2) + + +def autocorr_fftn(img): + fft = numpy.fft.fftn(img) + return correlate_fftns(fft, fft) + + +def ccorr_and_autocorr_fftn(img1, img2): + # TODO do we want to pad this? + fft1 = numpy.fft.fftn(img1) + fft2 = numpy.fft.fftn(img2) + ccorr = correlate_fftns(fft1, fft2) + acorr1 = correlate_fftns(fft1, fft1) + acorr2 = correlate_fftns(fft2, fft2) + return ccorr, acorr1, acorr2 + + +def subpixel_maximum(arr): + max_loc = numpy.unravel_index(numpy.argmax(arr), arr.shape) + + sub_arr = arr[ + tuple(slice(ml-1, ml+2) for ml in max_loc) + ] + + # get center of mass of sub_arr + subpixel_max_loc = numpy.array(scipy.ndimage.center_of_mass(sub_arr)) - 1 + return subpixel_max_loc + max_loc + + +def ccorr_disp(img1, img2, autocorrelation_threshold=None, padarray=False, value_threshold=0): + if padarray: + d = numpy.ceil(numpy.array(img1.shape) / 2) + pw = numpy.asarray([(di,di) for di in d],dtype=int) + img1 = numpy.pad(img1,pw) + img2 = numpy.pad(img2,pw) + if value_threshold: + img1[img1 0) and (not numpy.isnan(ac2max) and ac2max > 0): + autocorrelation_ratio = cc.max() / (numpy.sqrt(ac1max*ac2max)) + print(autocorrelation_ratio) + if autocorrelation_ratio < autocorrelation_threshold: + # what to do here? + return None + else: + return None + else: + cc = ccorr_fftn(img1, img2) + max_loc = subpixel_maximum(cc) + mid_point = numpy.array(img1.shape) // 2 + return max_loc - mid_point + + +def get_point_correspondence(src_pt, dst_pt, src_patch, dst_patch, autocorrelation_threshold=0.8,padarray=False,value_threshold=0): + disp = ccorr_disp(src_patch, dst_patch, autocorrelation_threshold, padarray,value_threshold) + if disp is not None: + return src_pt, dst_pt - disp + return None,None \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index ab11906..1a2b729 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -4,9 +4,12 @@ @author: kevint """ - +import numpy from acpreprocessing.stitching_modules.acstitch.sift_stitch import SiftDetector +from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src +from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file + def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] @@ -18,4 +21,56 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): if not p_pts is None and len(p_pts) > 0: pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) - return pmlist \ No newline at end of file + return pmlist + + +def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None): + if "sift_pointmatch_file" in ccorr_kwargs and ccorr_kwargs["sift_pointmatch_file"]: + sift_pmlist = read_pointmatch_file(ccorr_kwargs["sift_pointmatch_file"]) + else: + sift_pmlist = None + p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] + q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] + pmlist = [] + for i in range(len(p_datasets)): + pds = p_datasets[i] + qds = q_datasets[i] + if not sift_pmlist is None: + ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) + else: + ppts,qpts = run_ccorr(**ccorr_kwargs) + if not ppts is None and len(ppts) > 0: + pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts}) + return pmlist + + +def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,**kwargs): + p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) + ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array) + return ppm,qpm + +def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): + if axis_range is None: + axis_range = [[] for i in range(p_siftpts.shape[1])] + zstarts = numpy.linspace(numpy.min(p_siftpts,axis=0),numpy.max(p_siftpts,axis=0),n_cc_pts+1) + p_pts = numpy.empty((n_cc_pts,3),dtype=int) + q_pts = numpy.empty((n_cc_pts,3),dtype=int) + for i in range(n_cc_pts): + if len(axis_range[0]) == 0: + axis_range[0] = [zstarts[i],zstarts[i+1]] + r = numpy.full(p_siftpts.shape,True) + for i,a in enumerate(axis_range): + if a: + r = r & ((p_siftpts[:,i]>=a[0]) & (p_siftpts[:,i]<=a[1])) + pr = p_siftpts[r] + imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) + ppt = pr[imax,:] + print(p_ds[0,0,ppt[0],ppt[1],ppt[2]]) + p_pts[i] = ppt + q_pts[i] = ppt + numpy.array(axis_shift) + return p_pts,q_pts + + + +def run_ccorr(**kwargs): + pass \ No newline at end of file From bf6acb7691476427fbe709f65d9e37eea76e432e Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 13:00:01 -0700 Subject: [PATCH 09/37] debugging --- acpreprocessing/stitching_modules/acstitch/stitch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 1a2b729..8b75e1c 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -50,6 +50,7 @@ def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[ return ppm,qpm def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): + # TODO: handle overly granular bins with potentially 0 sift points returned if axis_range is None: axis_range = [[] for i in range(p_siftpts.shape[1])] zstarts = numpy.linspace(numpy.min(p_siftpts,axis=0),numpy.max(p_siftpts,axis=0),n_cc_pts+1) @@ -60,6 +61,7 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= axis_range[0] = [zstarts[i],zstarts[i+1]] r = numpy.full(p_siftpts.shape,True) for i,a in enumerate(axis_range): + print(a) if a: r = r & ((p_siftpts[:,i]>=a[0]) & (p_siftpts[:,i]<=a[1])) pr = p_siftpts[r] From c4e01c504b0c9a564a79f841b8add141485dc531 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 14:03:41 -0700 Subject: [PATCH 10/37] fixing ccorr points from sift --- acpreprocessing/stitching_modules/acstitch/stitch.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 8b75e1c..a23c5e9 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -53,17 +53,18 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= # TODO: handle overly granular bins with potentially 0 sift points returned if axis_range is None: axis_range = [[] for i in range(p_siftpts.shape[1])] - zstarts = numpy.linspace(numpy.min(p_siftpts,axis=0),numpy.max(p_siftpts,axis=0),n_cc_pts+1) + if len(axis_range[0]) == 0: + zstarts = numpy.linspace(numpy.min(p_siftpts[:,0]),numpy.max(p_siftpts[:,0]),n_cc_pts+1) p_pts = numpy.empty((n_cc_pts,3),dtype=int) q_pts = numpy.empty((n_cc_pts,3),dtype=int) for i in range(n_cc_pts): - if len(axis_range[0]) == 0: - axis_range[0] = [zstarts[i],zstarts[i+1]] r = numpy.full(p_siftpts.shape,True) for i,a in enumerate(axis_range): print(a) - if a: + if len(a)>0: r = r & ((p_siftpts[:,i]>=a[0]) & (p_siftpts[:,i]<=a[1])) + elif i == 0: + r = r & ((p_siftpts[:,i]>=zstarts[i]) & (p_siftpts[:,i]<=zstarts[i+1])) pr = p_siftpts[r] imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) ppt = pr[imax,:] From 3e6276e518200fd82ed462b3b0f428b0abb43f27 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 14:05:41 -0700 Subject: [PATCH 11/37] fixing bugs --- acpreprocessing/stitching_modules/acstitch/stitch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index a23c5e9..e5fbf14 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -58,7 +58,7 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= p_pts = numpy.empty((n_cc_pts,3),dtype=int) q_pts = numpy.empty((n_cc_pts,3),dtype=int) for i in range(n_cc_pts): - r = numpy.full(p_siftpts.shape,True) + r = numpy.full(p_siftpts.shape[0],True) for i,a in enumerate(axis_range): print(a) if len(a)>0: From 8f1e05123b326fbc3230a323b86e3ac041176065 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 14:39:52 -0700 Subject: [PATCH 12/37] bug fixes --- acpreprocessing/stitching_modules/acstitch/stitch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index e5fbf14..1d7dad7 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -46,6 +46,8 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,**kwargs): p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) + print(p_pts) + print(q_pts) ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array) return ppm,qpm @@ -60,13 +62,13 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= for i in range(n_cc_pts): r = numpy.full(p_siftpts.shape[0],True) for i,a in enumerate(axis_range): - print(a) if len(a)>0: r = r & ((p_siftpts[:,i]>=a[0]) & (p_siftpts[:,i]<=a[1])) elif i == 0: r = r & ((p_siftpts[:,i]>=zstarts[i]) & (p_siftpts[:,i]<=zstarts[i+1])) pr = p_siftpts[r] imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) + print(imax) ppt = pr[imax,:] print(p_ds[0,0,ppt[0],ppt[1],ppt[2]]) p_pts[i] = ppt From 963975f284f1b716cec3cf404ccbfefcd56e6246 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 14:46:09 -0700 Subject: [PATCH 13/37] bug fixes --- acpreprocessing/stitching_modules/acstitch/stitch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 1d7dad7..509e83f 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -61,11 +61,11 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= q_pts = numpy.empty((n_cc_pts,3),dtype=int) for i in range(n_cc_pts): r = numpy.full(p_siftpts.shape[0],True) - for i,a in enumerate(axis_range): + for ai,a in enumerate(axis_range): if len(a)>0: - r = r & ((p_siftpts[:,i]>=a[0]) & (p_siftpts[:,i]<=a[1])) - elif i == 0: - r = r & ((p_siftpts[:,i]>=zstarts[i]) & (p_siftpts[:,i]<=zstarts[i+1])) + r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1])) + elif ai == 0: + r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) pr = p_siftpts[r] imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) print(imax) From e6c4f3d17ee467df92bb871d165746a63f043d6f Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 15:13:14 -0700 Subject: [PATCH 14/37] cleaning print statements and debugging --- acpreprocessing/stitching_modules/acstitch/rtccorr.py | 2 +- acpreprocessing/stitching_modules/acstitch/stitch.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/rtccorr.py b/acpreprocessing/stitching_modules/acstitch/rtccorr.py index 502bb7b..9c8c182 100644 --- a/acpreprocessing/stitching_modules/acstitch/rtccorr.py +++ b/acpreprocessing/stitching_modules/acstitch/rtccorr.py @@ -60,9 +60,9 @@ def ccorr_disp(img1, img2, autocorrelation_threshold=None, padarray=False, value ac2max = ac2.max() if (not numpy.isnan(ac1max) and ac1max > 0) and (not numpy.isnan(ac2max) and ac2max > 0): autocorrelation_ratio = cc.max() / (numpy.sqrt(ac1max*ac2max)) - print(autocorrelation_ratio) if autocorrelation_ratio < autocorrelation_threshold: # what to do here? + print("ratio below threshold: " + str(autocorrelation_ratio)) return None else: return None diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 509e83f..e14255b 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -46,8 +46,6 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,**kwargs): p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) - print(p_pts) - print(q_pts) ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array) return ppm,qpm @@ -68,9 +66,7 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) pr = p_siftpts[r] imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) - print(imax) - ppt = pr[imax,:] - print(p_ds[0,0,ppt[0],ppt[1],ppt[2]]) + ppt = pr[imax,:] p_pts[i] = ppt q_pts[i] = ppt + numpy.array(axis_shift) return p_pts,q_pts From 94059af1753068af81add9ac21b31b80268d9e97 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Tue, 1 Aug 2023 15:17:02 -0700 Subject: [PATCH 15/37] print iteration --- acpreprocessing/stitching_modules/acstitch/stitch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index e14255b..3b058fa 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -33,6 +33,7 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] pmlist = [] for i in range(len(p_datasets)): + print("computing pointmatches for source pair " + str(i)) pds = p_datasets[i] qds = q_datasets[i] if not sift_pmlist is None: From fa0c56654d9e88bcb9a97fd2e4070429037c5ef0 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Sun, 6 Aug 2023 09:30:58 -0700 Subject: [PATCH 16/37] handle no sift points in range --- acpreprocessing/stitching_modules/acstitch/stitch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 3b058fa..347bc85 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -66,8 +66,11 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= elif ai == 0: r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) pr = p_siftpts[r] - imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) - ppt = pr[imax,:] + if len(pr) > 0: + imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) + ppt = pr[imax,:] + else: + ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int) p_pts[i] = ppt q_pts[i] = ppt + numpy.array(axis_shift) return p_pts,q_pts From 64ec1b7665921b364e57311f0d211c399ecc4b17 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Sun, 6 Aug 2023 16:31:52 -0700 Subject: [PATCH 17/37] handle Nones --- acpreprocessing/stitching_modules/acstitch/io.py | 3 ++- acpreprocessing/stitching_modules/acstitch/stitch.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/io.py b/acpreprocessing/stitching_modules/acstitch/io.py index 365fa78..3d1e014 100644 --- a/acpreprocessing/stitching_modules/acstitch/io.py +++ b/acpreprocessing/stitching_modules/acstitch/io.py @@ -27,5 +27,6 @@ def read_pointmatch_file(jsonpath): if data: for tspec in data: for key in ["p_pts","q_pts"]: - tspec[key] = numpy.asarray(tspec[key]) + if not tspec[key] is None: + tspec[key] = numpy.asarray(tspec[key]) return data \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 347bc85..1d3c77c 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -21,6 +21,8 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): if not p_pts is None and len(p_pts) > 0: pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":p_pts,"q_pts":q_pts}) + else: + pmlist.append({"p_tile":p_src,"q_tile":q_src,"p_pts":None,"q_pts":None}) return pmlist From 1baa12de8dd48753a7d8c79282fbd24d2e5a05d6 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Sun, 6 Aug 2023 16:37:57 -0700 Subject: [PATCH 18/37] handle None in sift list --- acpreprocessing/stitching_modules/acstitch/stitch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 1d3c77c..9d669ba 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -39,11 +39,17 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s pds = p_datasets[i] qds = q_datasets[i] if not sift_pmlist is None: - ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) + if not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: + ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) + else: + ppts = None + qpts = None else: ppts,qpts = run_ccorr(**ccorr_kwargs) if not ppts is None and len(ppts) > 0: pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":ppts,"q_pts":qpts}) + else: + pmlist.append({"p_tile":p_srclist[i],"q_tile":q_srclist[i],"p_pts":None,"q_pts":None}) return pmlist From c3a8123e618cc8cf84051fe3f1d8353713a00c19 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 7 Aug 2023 19:43:43 -0700 Subject: [PATCH 19/37] fixing zy sift stitching --- .../stitching_modules/acstitch/sift_stitch.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index b5b6949..2f3eddf 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -178,13 +178,12 @@ def run_zx_stitch(self, t.append(None) return siftstitch - def run_zy_stitch(sift_detector, + def run_zy_stitch(self, p_srclist, q_srclist, z0,z1,i_slice,j_slice,nx,dx, scatter=False): # estimate translation between strips with the median 2D displacement of matching point correspondences - kms = np.zeros((len(p_srclist),2)) siftklist = [] j_slices = np.ones(len(p_srclist),dtype=int)*j_slice for i,dsRef in enumerate(p_srclist): @@ -193,18 +192,17 @@ def run_zy_stitch(sift_detector, dsStack = q_srclist[i] imgRef = dsRef[0,0,z0:z1,:,i_slice] # detect SIFT keypoints for reference slice - kp1, des1, cimg1 = sift_detector.detect_keypoints(imgRef) + kp1, des1, cimg1 = self.detect_keypoints(imgRef) imgStack = dsStack[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] # detect correspondences in slices from neighboring strip - k1_tot,k2_tot,good,k2slice = sift_detector.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=False) + k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=True) print("Number of correspondences: " + str(good)) - if not k1_tot is None and k1_tot.shape[0]>50: + if not k1_tot is None and k1_tot.shape[0]>100: k = k2_tot-k1_tot print('total correspondences for analysis: ' + str(k.shape[0])) # estimate stitching translation with median displacement km = np.median(k,axis=0) print('median pixel displacements:' + str(km)) - kms[i] = km # display scatter of displacements around median estimate if scatter: plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) @@ -217,23 +215,18 @@ def run_zy_stitch(sift_detector, else: print("not enough correspondences for strip " + str(i)) siftklist.append(None) - - trest = np.zeros((kms.shape[0],3)) - trest[:,0] = kms[:,1] - trest[:,1] = kms[:,0] - trest[:,2] = j_slices - j_slice - + siftstitch = [[],[]] for i,s in enumerate(siftklist): zi = z0 - yi = j_slices[i] + xi = j_slices[i] for ii,t in enumerate(siftstitch): if not s is None: kzyx = np.empty((s[ii].shape[0],3)) kzyx[:,0] = s[ii][:,0] + zi - kzyx[:,2] = i_slice if ii == 0 else yi kzyx[:,1] = s[ii][:,1] + kzyx[:,2] = i_slice if ii == 0 else xi t.append(kzyx) else: t.append(None) - return siftstitch,trest \ No newline at end of file + return siftstitch \ No newline at end of file From 99131a2bacb62b33d1b75408dca40648f90d276f Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 7 Aug 2023 19:45:43 -0700 Subject: [PATCH 20/37] adding zy to stitching --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 2f3eddf..61cab5e 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -111,6 +111,10 @@ def stitch_over_segments(self,slice_axes,p_dslist,q_dslist,zstarts,zlength,**kwa for zs in zstarts: seglist = self.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) pmlist.append(seglist) + elif slice_axes == "zy": + for zs in zstarts: + seglist = self.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + pmlist.append(seglist) if pmlist: n_tiles = len(pmlist[0][0]) tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] From 0948ff7a0d0159fc38c40a5262cfb51611875989 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 7 Aug 2023 19:47:32 -0700 Subject: [PATCH 21/37] stitch axes --- acpreprocessing/stitching_modules/acstitch/stitch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 9d669ba..4342eef 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -11,11 +11,11 @@ from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file -def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): +def generate_sift_pointmatches(p_srclist,q_srclist,stitch_axes="zx",miplvl=0,sift_kwargs=None,stitch_kwargs=None): p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] sd = SiftDetector(**sift_kwargs) - p_ptlist,q_ptlist = sd.stitch_over_segments("zx",p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) + p_ptlist,q_ptlist = sd.stitch_over_segments(stitch_axes,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): From de1643c9a6f7c03582588d52d980f4a617fca881 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Wed, 9 Aug 2023 09:01:30 -0700 Subject: [PATCH 22/37] ccorr updates --- .../stitching_modules/acstitch/sift_stitch.py | 6 +++--- .../stitching_modules/acstitch/stitch.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 61cab5e..dce118b 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -105,13 +105,13 @@ def detect_and_combine(self,kp1,des1,cimg1,imgStack,draw=False,axis=2,max_only=F return None,None,None,None - def stitch_over_segments(self,slice_axes,p_dslist,q_dslist,zstarts,zlength,**kwargs): + def stitch_over_segments(self,p_dslist,q_dslist,zstarts,zlength,stitch_axes,**kwargs): pmlist = [] - if slice_axes == "zx": + if stitch_axes == "zx": for zs in zstarts: seglist = self.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) pmlist.append(seglist) - elif slice_axes == "zy": + elif stitch_axes == "zy": for zs in zstarts: seglist = self.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) pmlist.append(seglist) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 4342eef..120d15d 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -11,11 +11,15 @@ from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file -def generate_sift_pointmatches(p_srclist,q_srclist,stitch_axes="zx",miplvl=0,sift_kwargs=None,stitch_kwargs=None): +def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] sd = SiftDetector(**sift_kwargs) - p_ptlist,q_ptlist = sd.stitch_over_segments(stitch_axes,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) + if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: + sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) + else: + sift_pmlist = None + p_ptlist,q_ptlist = sd.stitch_over_segments(p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): @@ -27,8 +31,9 @@ def generate_sift_pointmatches(p_srclist,q_srclist,stitch_axes="zx",miplvl=0,sif def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,stitch_kwargs=None): - if "sift_pointmatch_file" in ccorr_kwargs and ccorr_kwargs["sift_pointmatch_file"]: - sift_pmlist = read_pointmatch_file(ccorr_kwargs["sift_pointmatch_file"]) + if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: + print("running crosscorrelation with points from " + stitch_kwargs["sift_pointmatch_file"]) + sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) else: sift_pmlist = None p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] From f6d1b7eb3e1783b869220a99db9717eae087772c Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:10:13 -0700 Subject: [PATCH 23/37] stitching planes with sift pointmatches --- .../stitching_modules/acstitch/sift_stitch.py | 124 +++++++++++++++++- .../stitching_modules/acstitch/stitch.py | 8 +- 2 files changed, 128 insertions(+), 4 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index dce118b..570b6f9 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -5,6 +5,45 @@ import cv2 as cv from matplotlib import pyplot as plt + +def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): + roilist = [] + for pm in pm_list: + p_siftpts = pm["p_pts"] + z,roipts = get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_dims[0]) + if len(roipts==0): + roilist.append(None) + else: + if len(roipts.shape) == 1: + roipts = roipts[np.newaxis,:] + ptsmean = np.mean(roipts,axis=0) + y = ptsmean[1] + x = ptsmean[2] + roilist.append([[z,z+roi_dims[0]],[y,y+roi_dims[1]],[x,x+roi_dims[2]]]) + return roilist + + +def get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_length): + if axis_range is None: + axis_range = [[] for i in range(p_siftpts.shape[1])] + if len(axis_range[0]) == 0: + axis_range[0] = [int(np.min(p_siftpts[:,0])),int(np.max(p_siftpts[:,0]))] + zstarts = np.arange(axis_range[0][0],axis_range[0][1],int(roi_length/2)) + p_pts = [] + z = zstarts[0] + for zs in zstarts: + r = np.full(p_siftpts.shape[0],True) + for ai,a in enumerate(axis_range): + if len(a)>0: + r = r & ((p_siftpts[:,ai]>=a[0]) & (p_siftpts[:,ai]<=a[1])) + elif ai == 0: + r = r & ((p_siftpts[:,ai]>=zs) & (p_siftpts[:,ai]<=zs+roi_length)) + pr = p_siftpts[r] + if len(pr) > len(p_pts): + p_pts = pr + z = zs + return z,p_pts + class SiftDetector(object): def __init__(self,clahe_kwargs,sift_kwargs,flann_args,ratio=0.7,min_inliers=100): # CLAHE equalizer @@ -128,13 +167,92 @@ def stitch_over_segments(self,p_dslist,q_dslist,zstarts,zlength,stitch_axes,**kw return p_ptlist,q_ptlist else: return None,None + + + def stitch_over_rois(self,p_dslist,q_dslist,roi_list,stitch_axes,ij_shift,**kwargs): + ''' + Parameters + ---------- + p_dslist : TYPE + DESCRIPTION. + q_dslist : TYPE + DESCRIPTION. + roi_list : TYPE + DESCRIPTION. + stitch_axes : TYPE + DESCRIPTION. + ij_shift : TYPE + DESCRIPTION. + **kwargs : TYPE + DESCRIPTION. + Returns + ------- + p_ptlist,q_ptlist: lists of ndarray zyx coordinates of point matches for each pq pair in dslists + + ''' + pmlist = [] + if stitch_axes == "zy": + for p_src,q_src,roi in zip(p_dslist,q_dslist,roi_list): + if not roi is None: + k1_tot,k2_tot,j_slice = self.zy_stitch(p_src,q_src,roi[0][0],roi[0][1],roi[2][0],roi[2][0]+ij_shift,**kwargs) + pmlist.append((k1_tot,k2_tot,j_slice)) + if pmlist: + pq_lists = [[],[]] + for i,s in enumerate(pmlist): + zi = roi_list[0][0] + xi = roi_list[0][2] + for ii,t in enumerate(pq_lists): + if not s is None: + kzyx = np.empty((s[ii].shape[0],3)) + kzyx[:,0] = s[ii][:,0] + zi + kzyx[:,1] = s[ii][:,1] + kzyx[:,2] = xi if ii == 0 else s[2] + t.append(kzyx) + else: + t.append(None) + p_ptlist = pq_lists[0] + q_ptlist = pq_lists[1] + return p_ptlist,q_ptlist + else: + return None,None + + + def zy_stitch(self,p_src,q_src,z0,z1,p_slice,q_slice,nx,dx,scatter=False): + ji = q_slice + imgRef = p_src[0,0,z0:z1,:,p_slice] + # detect SIFT keypoints for reference slice + kp1, des1, cimg1 = self.detect_keypoints(imgRef) + imgStack = q_src[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] + # detect correspondences in slices from neighboring strip + k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=True) + print("Number of correspondences: " + str(good)) + if not k1_tot is None and k1_tot.shape[0]>100: + k = k2_tot-k1_tot + print('total correspondences for analysis: ' + str(k.shape[0])) + # estimate stitching translation with median displacement + km = np.median(k,axis=0) + print('median pixel displacements:' + str(km)) + # display scatter of displacements around median estimate + if scatter: + plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) + plt.xlim((-5,5)) + plt.ylim((-5,5)) + plt.show() + # identify slice index with most correspondences + j_slice = ji - nx*dx + dx*np.argmax(good) + return k1_tot,k2_tot,j_slice + else: + print("not enough correspondences") + return None,None,None + def run_zx_stitch(self, p_srclist, q_srclist, - z0,z1,i_slice,j_slice,ny,dy, + z0,z1,i_slice,ij_shift,ny,dy, scatter=False): + j_slice = i_slice + ij_shift # estimate translation between strips with the median 2D displacement of matching point correspondences siftklist = [] j_slices = np.ones(len(p_srclist),dtype=int)*j_slice @@ -182,11 +300,13 @@ def run_zx_stitch(self, t.append(None) return siftstitch + def run_zy_stitch(self, p_srclist, q_srclist, - z0,z1,i_slice,j_slice,nx,dx, + z0,z1,i_slice,ij_shift,nx,dx, scatter=False): + j_slice = i_slice + ij_shift # estimate translation between strips with the median 2D displacement of matching point correspondences siftklist = [] j_slices = np.ones(len(p_srclist),dtype=int)*j_slice diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 120d15d..b196a24 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -5,7 +5,7 @@ @author: kevint """ import numpy -from acpreprocessing.stitching_modules.acstitch.sift_stitch import SiftDetector +from acpreprocessing.stitching_modules.acstitch.sift_stitch import SiftDetector,generate_rois_from_pointmatches from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file @@ -19,7 +19,11 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) else: sift_pmlist = None - p_ptlist,q_ptlist = sd.stitch_over_segments(p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, j_slice, ny, dy) + if sift_pmlist is None: + p_ptlist,q_ptlist = sd.stitch_over_segments(p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, stitch_axes, i_slice, j_slice, ny, dy + else: + roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx + p_ptlist,q_ptlist = sd.stitch_over_rois(p_datasets,q_datasets,roilist,**stitch_kwargs) pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): From b955ccca1b01ca084d32ea65ef481362fdf6d05a Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:20:04 -0700 Subject: [PATCH 24/37] fixing roi stitching --- .../stitching_modules/acstitch/sift_stitch.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 570b6f9..da53350 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -17,9 +17,13 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): if len(roipts.shape) == 1: roipts = roipts[np.newaxis,:] ptsmean = np.mean(roipts,axis=0) - y = ptsmean[1] - x = ptsmean[2] - roilist.append([[z,z+roi_dims[0]],[y,y+roi_dims[1]],[x,x+roi_dims[2]]]) + if roi_dims[1] is None: + x = ptsmean[2] + roi = [[z,z+roi_dims[0]],[],[x,x+roi_dims[2]]] + elif roi_dims[2] is None: + y = ptsmean[1] + roi = [[z,z+roi_dims[0]],[y,y+roi_dims[1]],[]] + roilist.append(roi) return roilist @@ -44,6 +48,7 @@ def get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_length): z = zs return z,p_pts + class SiftDetector(object): def __init__(self,clahe_kwargs,sift_kwargs,flann_args,ratio=0.7,min_inliers=100): # CLAHE equalizer From 2ff41de83f4991bd8c97be5360cfa1e8c00f6b1a Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:26:05 -0700 Subject: [PATCH 25/37] testing --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index da53350..4d7f170 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -24,6 +24,7 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): y = ptsmean[1] roi = [[z,z+roi_dims[0]],[y,y+roi_dims[1]],[]] roilist.append(roi) + print(roilist) return roilist From b10b2b8549196161df54af2e24f2a37678428963 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:36:51 -0700 Subject: [PATCH 26/37] debugging --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 4d7f170..2c7a0eb 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -11,7 +11,9 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): for pm in pm_list: p_siftpts = pm["p_pts"] z,roipts = get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_dims[0]) + print(roipts) if len(roipts==0): + print("no roi points found") roilist.append(None) else: if len(roipts.shape) == 1: @@ -24,7 +26,6 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): y = ptsmean[1] roi = [[z,z+roi_dims[0]],[y,y+roi_dims[1]],[]] roilist.append(roi) - print(roilist) return roilist From 8b6ac04414d64dc756c5f97e7285a8257718ac51 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:39:05 -0700 Subject: [PATCH 27/37] debugging --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 2c7a0eb..3c1e1ac 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -11,8 +11,7 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): for pm in pm_list: p_siftpts = pm["p_pts"] z,roipts = get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_dims[0]) - print(roipts) - if len(roipts==0): + if len(roipts)==0: print("no roi points found") roilist.append(None) else: From f17e6a5f51498f4ae79b0e7b84d37e6415060ab7 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:41:32 -0700 Subject: [PATCH 28/37] debugging --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 3c1e1ac..c180375 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -19,10 +19,10 @@ def generate_rois_from_pointmatches(pm_list,axis_range,roi_dims,**kwargs): roipts = roipts[np.newaxis,:] ptsmean = np.mean(roipts,axis=0) if roi_dims[1] is None: - x = ptsmean[2] + x = int(ptsmean[2]) roi = [[z,z+roi_dims[0]],[],[x,x+roi_dims[2]]] elif roi_dims[2] is None: - y = ptsmean[1] + y = int(ptsmean[1]) roi = [[z,z+roi_dims[0]],[y,y+roi_dims[1]],[]] roilist.append(roi) return roilist From 4bf41ed2f9ed1b2f8f176567a769b2d3dbc73f5b Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 20:44:26 -0700 Subject: [PATCH 29/37] debugging --- acpreprocessing/stitching_modules/acstitch/sift_stitch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index c180375..0b574fd 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -224,7 +224,7 @@ def stitch_over_rois(self,p_dslist,q_dslist,roi_list,stitch_axes,ij_shift,**kwar return None,None - def zy_stitch(self,p_src,q_src,z0,z1,p_slice,q_slice,nx,dx,scatter=False): + def zy_stitch(self,p_src,q_src,z0,z1,p_slice,q_slice,nx,dx,scatter=False,**kwargs): ji = q_slice imgRef = p_src[0,0,z0:z1,:,p_slice] # detect SIFT keypoints for reference slice From 4528587439803db70bd01e4c05ddf1ded22c1f13 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 10 Aug 2023 21:28:41 -0700 Subject: [PATCH 30/37] reworking plane stitching --- acpreprocessing/stitching_modules/acstitch/stitch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index b196a24..1d58a64 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -48,7 +48,7 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s pds = p_datasets[i] qds = q_datasets[i] if not sift_pmlist is None: - if not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: + if i < len(sift_pmlist) and not sift_pmlist[i]["p_pts"] is None and len(sift_pmlist[i]["p_pts"])>0: ppts,qpts = run_ccorr_with_sift_points(pds, qds, sift_pmlist[i]["p_pts"].astype(int), sift_pmlist[i]["q_pts"].astype(int), **ccorr_kwargs) else: ppts = None From d8f3b9558a2835ff5623c3acf5e9286cea9a2b4d Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Fri, 11 Aug 2023 10:00:10 -0700 Subject: [PATCH 31/37] adding cc_threshold option --- acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py | 4 ++-- acpreprocessing/stitching_modules/acstitch/stitch.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py index 4a3c258..3c984f6 100644 --- a/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py @@ -2,7 +2,7 @@ from acpreprocessing.stitching_modules.acstitch.rtccorr import get_point_correspondence -def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False): +def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False,cc_threshold=0.8): w = numpy.asarray(w,dtype=int) if len(A1_pts.shape)<2: A1_pts = numpy.array([A1_pts]) @@ -12,7 +12,7 @@ def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False): for p,q in zip(A1_pts.astype(int),A2_pts.astype(int)): A2sub = A2_ds[0,0,(q-w)[0]:(q+w)[0],(q-w)[1]:(q+w)[1],(q-w)[2]:(q+w)[2]] A1sub = A1_ds[0,0,(p-w)[0]:(p+w)[0],(p-w)[1]:(p+w)[1],(p-w)[2]:(p+w)[2]] - p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=0.8,padarray=pad,value_threshold=1000) + p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=cc_threshold,padarray=pad,value_threshold=500) if not p1 is None: pm1.append(p1) pm2.append(p2) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 1d58a64..9f963ca 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -62,9 +62,9 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s return pmlist -def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,**kwargs): +def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) - ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array) + ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array,cc_threshold=cc_threshold) return ppm,qpm def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): From 1d7fe71124f0e6efd8cee9befa194392a63d056c Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Mon, 11 Sep 2023 12:44:21 -0700 Subject: [PATCH 32/37] updating stitch --- acpreprocessing/stitching_modules/acstitch/stitch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 9f963ca..489d63c 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -83,13 +83,16 @@ def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift= elif ai == 0: r = r & ((p_siftpts[:,ai]>=zstarts[i]) & (p_siftpts[:,ai]<=zstarts[i+1])) pr = p_siftpts[r] + qr = q_siftpts[r] if len(pr) > 0: imax = numpy.argmax(p_ds[0,0,pr[:,0],pr[:,1],pr[:,2]]) ppt = pr[imax,:] + qpt = qr[imax,:] else: ppt = numpy.array([(zstarts[i]+zstarts[i+1])/2,numpy.mean(p_siftpts[:,1]),numpy.mean(p_siftpts[:,2])],dtype=int) + qpt = ppt + numpy.array(axis_shift) p_pts[i] = ppt - q_pts[i] = ppt + numpy.array(axis_shift) + q_pts[i] = qpt return p_pts,q_pts From ae0e3c7cc6037d7e376a578c5371a23382aedef3 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 14 Sep 2023 11:09:19 -0700 Subject: [PATCH 33/37] implementing pathlib instead of os --- .../stitching_modules/acstitch/zarrutils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/zarrutils.py b/acpreprocessing/stitching_modules/acstitch/zarrutils.py index a1e6288..06234ab 100644 --- a/acpreprocessing/stitching_modules/acstitch/zarrutils.py +++ b/acpreprocessing/stitching_modules/acstitch/zarrutils.py @@ -1,4 +1,4 @@ -import os +import pathlib import zarr def get_zarr_group(zpath,grpname): @@ -7,10 +7,14 @@ def get_zarr_group(zpath,grpname): zf = zarr.open(zpath) return zf[grpname] -def get_group_from_src(srcpath): +def get_group_from_src(srcpath, + outpath='zarr://http://bigkahuna.corp.alleninstitute.org/ACdata', # Url for ACdata for NG hosted on BigKahuna + inpath = 'J:'): # returns zarr group given a neuroglancer source path # used to get datasets from neuroglancer layer json - pathout = 'zarr://http://bigkahuna.corp.alleninstitute.org/ACdata' # Url for ACdata for NG hosted on BigKahuna - pathin = 'J:' # Local path to ACdata - s = os.path.split(srcpath.replace(pathout,pathin)) - return get_zarr_group(zpath=s[0],grpname=s[1]) \ No newline at end of file + p = pathlib.Path(srcpath.replace(outpath,inpath)) + if p.exists(): + return get_zarr_group(p.parent,p.name) + else: + print(str(p) + " not found!") + return None \ No newline at end of file From 2b9e1267d6c9e18abce7ce5a7250a5486d456be2 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 14 Sep 2023 13:51:00 -0700 Subject: [PATCH 34/37] adding special utility function --- .../stitching_modules/acstitch/zarrutils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/acpreprocessing/stitching_modules/acstitch/zarrutils.py b/acpreprocessing/stitching_modules/acstitch/zarrutils.py index 06234ab..9af58ec 100644 --- a/acpreprocessing/stitching_modules/acstitch/zarrutils.py +++ b/acpreprocessing/stitching_modules/acstitch/zarrutils.py @@ -1,5 +1,6 @@ import pathlib import zarr +import json def get_zarr_group(zpath,grpname): # key to working with zarr files @@ -17,4 +18,11 @@ def get_group_from_src(srcpath, return get_zarr_group(p.parent,p.name) else: print(str(p) + " not found!") - return None \ No newline at end of file + return None + +def get_src_from_json(sourcejson,plane,tile): + with open(sourcejson,'r') as f: + js = json.load(f) + srcList = js[plane]['sources'] + ind = [s.split("_")[-1] for s in srcList].index(tile) + return srcList[ind] \ No newline at end of file From 836e41a54ef2b306095511565feb7b6789f87adc Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 26 Oct 2023 12:19:32 -0700 Subject: [PATCH 35/37] refactor for ortho exaspim stitching --- .../acstitch/ccorr_stitch.py | 10 +- .../stitching_modules/acstitch/sift_stitch.py | 276 +++++++++++------- .../stitching_modules/acstitch/stitch.py | 12 +- 3 files changed, 181 insertions(+), 117 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py index 3c984f6..6088969 100644 --- a/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py @@ -2,7 +2,8 @@ from acpreprocessing.stitching_modules.acstitch.rtccorr import get_point_correspondence -def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False,cc_threshold=0.8): +def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,r=1,pad=False,cc_threshold=0.8,min_value=0): + cc_threshold /= r # TODO: is this the correct rescaling accounting for expanding reference? w = numpy.asarray(w,dtype=int) if len(A1_pts.shape)<2: A1_pts = numpy.array([A1_pts]) @@ -10,9 +11,12 @@ def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,pad=False,cc_threshold=0.8): pm1 = [] pm2 = [] for p,q in zip(A1_pts.astype(int),A2_pts.astype(int)): - A2sub = A2_ds[0,0,(q-w)[0]:(q+w)[0],(q-w)[1]:(q+w)[1],(q-w)[2]:(q+w)[2]] + A2sub = A2_ds[0,0,(q-r*w)[0]:(q+r*w)[0],(q-r*w)[1]:(q+r*w)[1],(q-r*w)[2]:(q+r*w)[2]] A1sub = A1_ds[0,0,(p-w)[0]:(p+w)[0],(p-w)[1]:(p+w)[1],(p-w)[2]:(p+w)[2]] - p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=cc_threshold,padarray=pad,value_threshold=500) + if r > 1: + pw = numpy.asarray([((r-1)*wi,(r-1)*wi) for wi in w],dtype=int) + A1sub = numpy.pad(A1sub,pw) + p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=cc_threshold,padarray=pad,value_threshold=min_value) if not p1 is None: pm1.append(p1) pm2.append(p2) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 0b574fd..463df2f 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -50,6 +50,134 @@ def get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_length): return z,p_pts +def stitch_over_segments(p_dslist,q_dslist,zstarts,zlength,stitch_axes,sd_kwargs,**kwargs): + pmlist = [] + sd = SiftDetector(**sd_kwargs) + if stitch_axes == "zx": + for zs in zstarts: + seglist = sd.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + pmlist.append(seglist) + elif stitch_axes == "zy": + for zs in zstarts: + seglist = sd.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + pmlist.append(seglist) + if pmlist: + n_tiles = len(pmlist[0][0]) + tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] + for pm in pmlist: + for i in range(2): + for ii in range(n_tiles): + if not pm[i][ii] is None: + tile_pmlist[i][ii].append(pm[i][ii]) + p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] + q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] + return p_ptlist,q_ptlist + else: + return None,None + + +def stitch_over_rois(sd_kwargs,p_dslist,q_dslist,axis_type,roi_list,ij_shift,ns,ds,s0=0,**kwargs): + ''' + Parameters + ---------- + p_dslist : TYPE + DESCRIPTION. + q_dslist : TYPE + DESCRIPTION. + axis_type : TYPE + DESCRIPTION. + roi_list : TYPE + DESCRIPTION. + stitch_axes : TYPE + DESCRIPTION. + ij_shift : TYPE + DESCRIPTION. + **kwargs : TYPE + DESCRIPTION. + + Returns + ------- + p_ptlist,q_ptlist: lists of ndarray zyx coordinates of point matches for each pq pair in dslists + + ''' + pmlist = [] + sd = SiftDetector(**sd_kwargs) + for p_src,q_src,roi in zip(p_dslist,q_dslist,roi_list): + if not roi is None: + # k1_tot,k2_tot,j_slice = self.zy_stitch(p_src,q_src,roi[0][0],roi[0][1],roi[2][0],roi[2][0]+ij_shift,**kwargs) + if axis_type == "ispim": + axis = 1 + z0,z1 = roi[0] + py = roi[1] + qs = roi[1] + ij_shift + s0 + x0,x1 = roi[2] + p_img = p_src[0,0,z0:z1,py,x0:x1] + q_stack = q_src[0,0,z0:z1,(qs-ns*ds):(qs+(ns+1)*ds):ds,x0:x1] + k2_add = np.array([0,0]) + elif axis_type == "zyx": + axis = 2 + z0,z1 = roi[0] + y0,y1 = roi[1] + px = roi[2] + qs = roi[2] + s0 + p_img = p_src[0,0,z0:z1,y0:y1,px] + q_stack = q_src[0,0,z0:z1,y0+ij_shift:y1+ij_shift,(qs-ns*ds):(qs+(ns+1)*ds):ds] + k2_add = np.array([ij_shift,0]) + elif axis_type == "xzy": + axis = 2 + z0,z1 = roi[1] + y0,y1 = roi[2] + px = roi[0] + qs = roi[0] + s0 + p_img = p_src[0,0,px,z0:z1,y0:y1] + q_stack = q_src[0,0,(qs-ns*ds):(qs+(ns+1)*ds):ds,z0:z1,y0+ij_shift:y1+ij_shift].transpose((1,2,0)) + k2_add = np.array([ij_shift,0]) + k1_tot,k2_tot,best_slice = sd.detect_in_best_slice(p_img, q_stack, axis=axis, **kwargs) + k2_tot += k2_add + if not k1_tot is None: + k2_slice = qs - ns*ds + ds*best_slice + pmlist.append((k1_tot,k2_tot,k2_slice)) + else: + pmlist.append(None) + else: + pmlist.append(None) + if pmlist: + pq_lists = [[],[]] + for s,roi in zip(pmlist,roi_list): + for ii,t in enumerate(pq_lists): + if not s is None: + k = np.empty((s[ii].shape[0],3)) + if axis_type == "ispim": + zi = roi[0][0] + xi = roi[2][0] + y = roi[1] + k[:,0] = s[ii][:,1] + zi + k[:,1] = y if ii == 0 else s[2] + k[:,2] = s[ii][:,0] + xi + elif axis_type == "zyx": + zi = roi[0][0] + yi = roi[1][0] + x = roi[2] + k[:,0] = s[ii][:,1] + zi + k[:,1] = s[ii][:,0] + yi + k[:,2] = x if ii == 0 else s[2] + elif axis_type == "xzy": + x = roi[0] + zi = roi[1][0] + yi = roi[2][0] + k[:,0] = x if ii == 0 else s[2] + k[:,1] = s[ii][:,1] + zi + k[:,2] = s[ii][:,0] + yi + t.append(k) + else: + t.append(None) + p_ptlist = pq_lists[0] + q_ptlist = pq_lists[1] + return p_ptlist,q_ptlist + else: + return None,None + + class SiftDetector(object): def __init__(self,clahe_kwargs,sift_kwargs,flann_args,ratio=0.7,min_inliers=100): # CLAHE equalizer @@ -62,16 +190,13 @@ def __init__(self,clahe_kwargs,sift_kwargs,flann_args,ratio=0.7,min_inliers=100) self.minin = min_inliers def detect_keypoints(self,img): - cimg1 = self.clahe.apply(img) - cimg1 = np.sqrt(cimg1).astype('uint8') - kp1, des1 = self.sift.detectAndCompute(cimg1,None) - return kp1,des1,cimg1 + cimg = self.clahe.apply(img) + cimg = np.sqrt(cimg).astype('uint8') + kp, des = self.sift.detectAndCompute(cimg,None) + return kp,des,cimg + - def detect_matches(self,kp1,des1,cimg1,img,draw=False): - cimg2 = self.clahe.apply(img) - cimg2 = np.sqrt(cimg2).astype('uint8') - # find the keypoints and descriptors with SIFT - kp2, des2 = self.sift.detectAndCompute(cimg2,None) + def compute_matches(self,kp1,des1,kp2,des2,draw=False,cimg1=None,cimg2=None): matches = self.flann.knnMatch(des1,des2,k=2) # Need to draw only good matches, so create a mask matchesMask = [[0,0] for i in range(len(matches))] @@ -82,8 +207,6 @@ def detect_matches(self,kp1,des1,cimg1,img,draw=False): matchesMask[i]=[1,0] good.append(m) #print(len(good)) - k1xy = np.array([np.array(k.pt) for k in kp1]) - k2xy = np.array([np.array(k.pt) for k in kp2]) k1 = [] k2 = [] if len(good)>self.minin: @@ -91,19 +214,15 @@ def detect_matches(self,kp1,des1,cimg1,img,draw=False): dst_pts = np.float32([ kp2[m.trainIdx].pt for m in good ]).reshape(-1,1,2) M, mask = cv.findHomography(src_pts, dst_pts, cv.RANSAC,5.0) matchesMask = mask.ravel().tolist() - if draw: - draw_params = dict(matchColor = (0,255,0), # draw matches in green color - singlePointColor = None, - matchesMask = matchesMask, # draw only inliers - flags = 2) - img4 = cv.drawMatches(cimg1,kp1,cimg2,kp2,good,None,**draw_params) - plt.figure(figsize=(20,20)) - plt.imshow(img4, 'gray'),plt.show() + if draw and not cimg1 is None and not cimg2 is None: + self.draw_matches(cimg1,kp1,cimg2,kp2,good,matchesMask) + goodMask = np.asarray(good)[np.asarray(matchesMask).astype('bool')] + imgIdx = np.array([g.imgIdx for g in goodMask]) + tIdx = np.array([g.trainIdx for g in goodMask]) + qIdx = np.array([g.queryIdx for g in goodMask]) + k1xy = np.array([np.array(k.pt) for k in kp1]) + k2xy = np.array([np.array(k.pt) for k in kp2]) - good = np.array(good)[np.array(matchesMask).astype('bool')] - imgIdx = np.array([g.imgIdx for g in good]) - tIdx = np.array([g.trainIdx for g in good]) - qIdx = np.array([g.queryIdx for g in good]) for i in range(len(tIdx)): if imgIdx[i] == 1: k1.append(k1xy[tIdx[i]]) @@ -111,7 +230,6 @@ def detect_matches(self,kp1,des1,cimg1,img,draw=False): else: k1.append(k1xy[qIdx[i]]) k2.append(k2xy[tIdx[i]]) - if len(k1)>0: k1 = np.array(k1) k2 = np.array(k2) @@ -121,8 +239,18 @@ def detect_matches(self,kp1,des1,cimg1,img,draw=False): np.random.shuffle(a) k1 = k1[a[0: 10000], :] k2 = k2[a[0: 10000], :] - - return k1,k2 + return k1,k2,good,matchesMask + + + def draw_matches(self,cimg1,kp1,cimg2,kp2,good,mask=[],**draw_kwargs): + draw_params = dict(matchColor = (0,255,0), # draw matches in green color + singlePointColor = None, + matchesMask = mask, # draw only inliers + flags = 2) + draw_params.update(draw_kwargs) + cvimg = cv.drawMatches(cimg1,kp1,cimg2,kp2,good,None,**draw_params) + return cvimg + def detect_and_combine(self,kp1,des1,cimg1,imgStack,draw=False,axis=2,max_only=False): k1_tot = [] @@ -134,7 +262,8 @@ def detect_and_combine(self,kp1,des1,cimg1,imgStack,draw=False,axis=2,max_only=F img = imgStack[:,:,i] elif axis==1: img = imgStack[:,i,:] - k1,k2 = self.detect_matches(kp1,des1,cimg1,img,draw) + kp2,des2,cimg2 = self.detect_keypoints(img) + k1,k2,_,__ = self.compute_matches(kp1,des1,kp2,des2) if isinstance(k1, np.ndarray): k1_tot.append(k1) k2_tot.append(k2) @@ -149,91 +278,17 @@ def detect_and_combine(self,kp1,des1,cimg1,imgStack,draw=False,axis=2,max_only=F else: return None,None,None,None - - def stitch_over_segments(self,p_dslist,q_dslist,zstarts,zlength,stitch_axes,**kwargs): - pmlist = [] - if stitch_axes == "zx": - for zs in zstarts: - seglist = self.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - pmlist.append(seglist) - elif stitch_axes == "zy": - for zs in zstarts: - seglist = self.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - pmlist.append(seglist) - if pmlist: - n_tiles = len(pmlist[0][0]) - tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] - for pm in pmlist: - for i in range(2): - for ii in range(n_tiles): - if not pm[i][ii] is None: - tile_pmlist[i][ii].append(pm[i][ii]) - p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] - q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] - return p_ptlist,q_ptlist - else: - return None,None - - - def stitch_over_rois(self,p_dslist,q_dslist,roi_list,stitch_axes,ij_shift,**kwargs): - ''' - Parameters - ---------- - p_dslist : TYPE - DESCRIPTION. - q_dslist : TYPE - DESCRIPTION. - roi_list : TYPE - DESCRIPTION. - stitch_axes : TYPE - DESCRIPTION. - ij_shift : TYPE - DESCRIPTION. - **kwargs : TYPE - DESCRIPTION. - - Returns - ------- - p_ptlist,q_ptlist: lists of ndarray zyx coordinates of point matches for each pq pair in dslists - - ''' - pmlist = [] - if stitch_axes == "zy": - for p_src,q_src,roi in zip(p_dslist,q_dslist,roi_list): - if not roi is None: - k1_tot,k2_tot,j_slice = self.zy_stitch(p_src,q_src,roi[0][0],roi[0][1],roi[2][0],roi[2][0]+ij_shift,**kwargs) - pmlist.append((k1_tot,k2_tot,j_slice)) - if pmlist: - pq_lists = [[],[]] - for i,s in enumerate(pmlist): - zi = roi_list[0][0] - xi = roi_list[0][2] - for ii,t in enumerate(pq_lists): - if not s is None: - kzyx = np.empty((s[ii].shape[0],3)) - kzyx[:,0] = s[ii][:,0] + zi - kzyx[:,1] = s[ii][:,1] - kzyx[:,2] = xi if ii == 0 else s[2] - t.append(kzyx) - else: - t.append(None) - p_ptlist = pq_lists[0] - q_ptlist = pq_lists[1] - return p_ptlist,q_ptlist - else: - return None,None - - def zy_stitch(self,p_src,q_src,z0,z1,p_slice,q_slice,nx,dx,scatter=False,**kwargs): - ji = q_slice - imgRef = p_src[0,0,z0:z1,:,p_slice] + def detect_in_best_slice(self,p_img,q_stack,axis,scatter=False,**kwargs): + # ji = q_slice + # imgRef = p_src[0,0,z0:z1,:,p_slice] # detect SIFT keypoints for reference slice - kp1, des1, cimg1 = self.detect_keypoints(imgRef) - imgStack = q_src[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] + kp1, des1, cimg1 = self.detect_keypoints(p_img) + # imgStack = q_src[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] # detect correspondences in slices from neighboring strip - k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=True) + k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,q_stack,False,axis=axis,max_only=True) print("Number of correspondences: " + str(good)) - if not k1_tot is None and k1_tot.shape[0]>100: + if not k1_tot is None and k1_tot.shape[0]>10: k = k2_tot-k1_tot print('total correspondences for analysis: ' + str(k.shape[0])) # estimate stitching translation with median displacement @@ -246,8 +301,9 @@ def zy_stitch(self,p_src,q_src,z0,z1,p_slice,q_slice,nx,dx,scatter=False,**kwarg plt.ylim((-5,5)) plt.show() # identify slice index with most correspondences - j_slice = ji - nx*dx + dx*np.argmax(good) - return k1_tot,k2_tot,j_slice + # j_slice = ji - nx*dx + dx*np.argmax(good) + best_slice = np.argmax(good) + return k1_tot,k2_tot,best_slice else: print("not enough correspondences") return None,None,None diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 489d63c..73517e1 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -5,7 +5,7 @@ @author: kevint """ import numpy -from acpreprocessing.stitching_modules.acstitch.sift_stitch import SiftDetector,generate_rois_from_pointmatches +from acpreprocessing.stitching_modules.acstitch.sift_stitch import generate_rois_from_pointmatches,stitch_over_rois,stitch_over_segments from acpreprocessing.stitching_modules.acstitch.ccorr_stitch import get_correspondences from acpreprocessing.stitching_modules.acstitch.zarrutils import get_group_from_src from acpreprocessing.stitching_modules.acstitch.io import read_pointmatch_file @@ -14,16 +14,20 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,stitch_kwargs=None): p_datasets = [get_group_from_src(src)[miplvl] for src in p_srclist] q_datasets = [get_group_from_src(src)[miplvl] for src in q_srclist] - sd = SiftDetector(**sift_kwargs) + # sd = SiftDetector(**sift_kwargs) if "sift_pointmatch_file" in stitch_kwargs and stitch_kwargs["sift_pointmatch_file"]: sift_pmlist = read_pointmatch_file(stitch_kwargs["sift_pointmatch_file"]) else: sift_pmlist = None if sift_pmlist is None: - p_ptlist,q_ptlist = sd.stitch_over_segments(p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, stitch_axes, i_slice, j_slice, ny, dy + if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: + roilist = stitch_kwargs["roi_list"] + p_ptlist,q_ptlist = stitch_over_rois(p_datasets,q_datasets,roilist,sift_kwargs,**stitch_kwargs) + else: + p_ptlist,q_ptlist = stitch_over_segments(p_datasets,q_datasets,sift_kwargs,**stitch_kwargs) # zstarts, zlength, stitch_axes, i_slice, j_slice, ny, dy else: roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx - p_ptlist,q_ptlist = sd.stitch_over_rois(p_datasets,q_datasets,roilist,**stitch_kwargs) + p_ptlist,q_ptlist = stitch_over_rois(p_datasets,q_datasets,roilist,sift_kwargs,**stitch_kwargs) pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): From 582209dd9a5abd0476b050f8ee25f6307a672947 Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 26 Oct 2023 12:59:28 -0700 Subject: [PATCH 36/37] update legacy stitching functions to run with refactor --- .../stitching_modules/acstitch/sift_stitch.py | 60 +++++++++++-------- .../stitching_modules/acstitch/stitch.py | 8 +-- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 463df2f..8942486 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -50,30 +50,42 @@ def get_roipoints_from_siftpoints(p_siftpts,axis_range,roi_length): return z,p_pts -def stitch_over_segments(p_dslist,q_dslist,zstarts,zlength,stitch_axes,sd_kwargs,**kwargs): - pmlist = [] - sd = SiftDetector(**sd_kwargs) - if stitch_axes == "zx": - for zs in zstarts: - seglist = sd.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - pmlist.append(seglist) - elif stitch_axes == "zy": - for zs in zstarts: - seglist = sd.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - pmlist.append(seglist) - if pmlist: - n_tiles = len(pmlist[0][0]) - tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] - for pm in pmlist: - for i in range(2): - for ii in range(n_tiles): - if not pm[i][ii] is None: - tile_pmlist[i][ii].append(pm[i][ii]) - p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] - q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] - return p_ptlist,q_ptlist - else: - return None,None +def stitch_over_segments(sd_kwargs,p_dslist,q_dslist,zstarts,zlength,i_slice,ij_shift,ns,ds,s0=0,**kwargs): + ''' stitch by segment for ispim data (legacy, axis_type = "ispim") + ''' + xdim = p_dslist[0].shape[4] + roi_list = [[[z,z+zlength],i_slice,[0,xdim]] for z in zstarts] + p_ptlist,q_ptlist = stitch_over_rois(sd_kwargs=sd_kwargs, + p_dslist=p_dslist, + q_dslist=q_dslist, + axis_type="ispim", + roi_list=roi_list, + ij_shift=ij_shift, + ns=ns, + ds=ds, + s0=s0) + return p_ptlist,q_ptlist + # if stitch_axes == "zx": + # for zs in zstarts: + # seglist = sd.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + # pmlist.append(seglist) + # elif stitch_axes == "zy": + # for zs in zstarts: + # seglist = sd.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) + # pmlist.append(seglist) + # if pmlist: + # n_tiles = len(pmlist[0][0]) + # tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] + # for pm in pmlist: + # for i in range(2): + # for ii in range(n_tiles): + # if not pm[i][ii] is None: + # tile_pmlist[i][ii].append(pm[i][ii]) + # p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] + # q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] + # return p_ptlist,q_ptlist + # else: + # return None,None def stitch_over_rois(sd_kwargs,p_dslist,q_dslist,axis_type,roi_list,ij_shift,ns,ds,s0=0,**kwargs): diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 73517e1..0cf2ded 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -22,12 +22,12 @@ def generate_sift_pointmatches(p_srclist,q_srclist,miplvl=0,sift_kwargs=None,sti if sift_pmlist is None: if "roi_list" in stitch_kwargs and not stitch_kwargs["roi_list"] is None: roilist = stitch_kwargs["roi_list"] - p_ptlist,q_ptlist = stitch_over_rois(p_datasets,q_datasets,roilist,sift_kwargs,**stitch_kwargs) + p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs) else: - p_ptlist,q_ptlist = stitch_over_segments(p_datasets,q_datasets,sift_kwargs,**stitch_kwargs) # zstarts, zlength, stitch_axes, i_slice, j_slice, ny, dy + p_ptlist,q_ptlist = stitch_over_segments(sift_kwargs,p_datasets,q_datasets,**stitch_kwargs) # zstarts, zlength, i_slice, ij_shift, ns, ds else: roilist = generate_rois_from_pointmatches(pm_list=sift_pmlist,**stitch_kwargs) # axis_range, roi_dims, stitch_axes, ij_shift, nx, dx - p_ptlist,q_ptlist = stitch_over_rois(p_datasets,q_datasets,roilist,sift_kwargs,**stitch_kwargs) + p_ptlist,q_ptlist = stitch_over_rois(sift_kwargs,p_datasets,q_datasets,roilist,**stitch_kwargs) pmlist = [] if not p_ptlist is None: for p_src,q_src,p_pts,q_pts in zip(p_srclist,q_srclist,p_ptlist,q_ptlist): @@ -68,7 +68,7 @@ def generate_ccorr_pointmatches(p_srclist,q_srclist,miplvl=0,ccorr_kwargs=None,s def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[32,32,32],pad_array=False,axis_shift=[0,0,0],axis_range=None,cc_threshold=0.8,**kwargs): p_pts,q_pts = get_cc_points_from_sift(p_ds, q_ds, p_siftpts, q_siftpts,n_cc_pts,axis_shift,axis_range) - ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad_array,cc_threshold=cc_threshold) + ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) return ppm,qpm def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): From 533843103652806e59a03803898f8f086df3275c Mon Sep 17 00:00:00 2001 From: Kevin Takasaki Date: Thu, 26 Oct 2023 13:01:11 -0700 Subject: [PATCH 37/37] cleaning up obsolete code --- .../stitching_modules/acstitch/sift_stitch.py | 132 +----------------- .../stitching_modules/acstitch/stitch.py | 1 + 2 files changed, 2 insertions(+), 131 deletions(-) diff --git a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py index 8942486..16ec00c 100644 --- a/acpreprocessing/stitching_modules/acstitch/sift_stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/sift_stitch.py @@ -65,27 +65,6 @@ def stitch_over_segments(sd_kwargs,p_dslist,q_dslist,zstarts,zlength,i_slice,ij_ ds=ds, s0=s0) return p_ptlist,q_ptlist - # if stitch_axes == "zx": - # for zs in zstarts: - # seglist = sd.run_zx_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - # pmlist.append(seglist) - # elif stitch_axes == "zy": - # for zs in zstarts: - # seglist = sd.run_zy_stitch(p_dslist,q_dslist,zs,zs+zlength,**kwargs) - # pmlist.append(seglist) - # if pmlist: - # n_tiles = len(pmlist[0][0]) - # tile_pmlist = [[[] for i in range(n_tiles)],[[] for i in range(n_tiles)]] - # for pm in pmlist: - # for i in range(2): - # for ii in range(n_tiles): - # if not pm[i][ii] is None: - # tile_pmlist[i][ii].append(pm[i][ii]) - # p_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[0]] - # q_ptlist = [np.concatenate(pm) if pm else [] for pm in tile_pmlist[1]] - # return p_ptlist,q_ptlist - # else: - # return None,None def stitch_over_rois(sd_kwargs,p_dslist,q_dslist,axis_type,roi_list,ij_shift,ns,ds,s0=0,**kwargs): @@ -318,113 +297,4 @@ def detect_in_best_slice(self,p_img,q_stack,axis,scatter=False,**kwargs): return k1_tot,k2_tot,best_slice else: print("not enough correspondences") - return None,None,None - - - def run_zx_stitch(self, - p_srclist, - q_srclist, - z0,z1,i_slice,ij_shift,ny,dy, - scatter=False): - j_slice = i_slice + ij_shift - # estimate translation between strips with the median 2D displacement of matching point correspondences - siftklist = [] - j_slices = np.ones(len(p_srclist),dtype=int)*j_slice - for i,dsRef in enumerate(p_srclist): - ji = j_slices[i] - # iterate over each strip and its subsequent neighbor to look for correspondences and estimate median - dsStack = q_srclist[i] - imgRef = dsRef[0,0,z0:z1,i_slice,:] - # detect SIFT keypoints for reference slice - kp1, des1, cimg1 = self.detect_keypoints(imgRef) - imgStack = dsStack[0,0,z0:z1,(ji-ny*dy):(ji+(ny+1)*dy):dy,:] - # detect correspondences in slices from neighboring strip - k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=1,max_only=True) - print("Number of correspondences: " + str(good)) - if not k1_tot is None and k1_tot.shape[0]>200: - k = k2_tot-k1_tot - print('total correspondences for analysis: ' + str(k.shape[0])) - # estimate stitching translation with median displacement - km = np.median(k,axis=0) - print('median pixel displacements:' + str(km)) - # display scatter of displacements around median estimate - if scatter: - plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) - plt.xlim((-5,5)) - plt.ylim((-5,5)) - plt.show() - # identify slice index with most correspondences - j_slices[i] = ji - ny*dy + dy*np.argmax(good) - siftklist.append((k1_tot,k2_tot,k2slice)) - else: - print("not enough correspondences for strip " + str(i)) - siftklist.append(None) - siftstitch = [[],[]] - for i,s in enumerate(siftklist): - zi = z0 - yi = j_slices[i] - for ii,t in enumerate(siftstitch): - if not s is None: - kzyx = np.empty((s[ii].shape[0],3)) - kzyx[:,0] = s[ii][:,0] + zi - kzyx[:,1] = i_slice if ii == 0 else yi - kzyx[:,2] = s[ii][:,1] - t.append(kzyx) - else: - t.append(None) - return siftstitch - - - def run_zy_stitch(self, - p_srclist, - q_srclist, - z0,z1,i_slice,ij_shift,nx,dx, - scatter=False): - j_slice = i_slice + ij_shift - # estimate translation between strips with the median 2D displacement of matching point correspondences - siftklist = [] - j_slices = np.ones(len(p_srclist),dtype=int)*j_slice - for i,dsRef in enumerate(p_srclist): - ji = j_slices[i] - # iterate over each strip and its subsequent neighbor to look for correspondences and estimate median - dsStack = q_srclist[i] - imgRef = dsRef[0,0,z0:z1,:,i_slice] - # detect SIFT keypoints for reference slice - kp1, des1, cimg1 = self.detect_keypoints(imgRef) - imgStack = dsStack[0,0,z0:z1,:,(ji-nx*dx):(ji+(nx+1)*dx):dx] - # detect correspondences in slices from neighboring strip - k1_tot,k2_tot,good,k2slice = self.detect_and_combine(kp1,des1,cimg1,imgStack,False,axis=2,max_only=True) - print("Number of correspondences: " + str(good)) - if not k1_tot is None and k1_tot.shape[0]>100: - k = k2_tot-k1_tot - print('total correspondences for analysis: ' + str(k.shape[0])) - # estimate stitching translation with median displacement - km = np.median(k,axis=0) - print('median pixel displacements:' + str(km)) - # display scatter of displacements around median estimate - if scatter: - plt.scatter(k[:,0]-km[0],k[:,1]-km[1],s=1) - plt.xlim((-5,5)) - plt.ylim((-5,5)) - plt.show() - # identify slice index with most correspondences - j_slices[i] = ji - nx*dx + dx*np.argmax(good) - siftklist.append((k1_tot,k2_tot,k2slice)) - else: - print("not enough correspondences for strip " + str(i)) - siftklist.append(None) - - siftstitch = [[],[]] - for i,s in enumerate(siftklist): - zi = z0 - xi = j_slices[i] - for ii,t in enumerate(siftstitch): - if not s is None: - kzyx = np.empty((s[ii].shape[0],3)) - kzyx[:,0] = s[ii][:,0] + zi - kzyx[:,1] = s[ii][:,1] - kzyx[:,2] = i_slice if ii == 0 else xi - t.append(kzyx) - else: - t.append(None) - return siftstitch \ No newline at end of file + return None,None,None \ No newline at end of file diff --git a/acpreprocessing/stitching_modules/acstitch/stitch.py b/acpreprocessing/stitching_modules/acstitch/stitch.py index 0cf2ded..9f20429 100644 --- a/acpreprocessing/stitching_modules/acstitch/stitch.py +++ b/acpreprocessing/stitching_modules/acstitch/stitch.py @@ -71,6 +71,7 @@ def run_ccorr_with_sift_points(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_w=[ ppm,qpm = get_correspondences(p_ds,q_ds,p_pts,q_pts,numpy.asarray(axis_w),pad=pad_array,cc_threshold=cc_threshold) return ppm,qpm + def get_cc_points_from_sift(p_ds,q_ds,p_siftpts,q_siftpts,n_cc_pts=1,axis_shift=[0,0,0],axis_range=None): # TODO: handle overly granular bins with potentially 0 sift points returned if axis_range is None: