Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
João Mamede committed Nov 3, 2023
1 parent 001969f commit bdedcdf
Showing 1 changed file with 111 additions and 59 deletions.
170 changes: 111 additions & 59 deletions src/napari_trackpy/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,56 +62,93 @@ def _get_choice_layer(self,_widget):
print("Layer where points are is:",j)
return index_layer

def make_labels_trackpy_links(shape,j,radius=5,_round=False):
def make_labels_trackpy_links(shape,j,radius=5,_round=False,_algo="CPU"):
import trackpy as tp
import scipy.ndimage as ndi
from scipy.ndimage import binary_dilation
import cupy as cp

#outputsomehow is 3D, we want 2
pos = cp.dstack((round(j.y),round(j.x)))[0].astype(int)

#this is super slow
# ~ masks = tp.masks.mask_image(coords,np.ones(image.shape),size/2)

##this is what tp.masks.mask_image does maybe put a cupy here to make if faster.
ndim = len(shape)
# radius = validate_tuple(radius, ndim)
pos = cp.atleast_2d(pos)

# if include_edge:
in_mask = cp.array([cp.sum(((cp.indices(shape).T - p) / radius)**2, -1) <= 1
for p in pos])
# else:
# in_mask = [np.sum(((np.indices(shape).T - p) / radius)**2, -1) < 1
# for p in pos]
mask_total = cp.any(in_mask, axis=0).T

##if they overlap the labels won't match the points
#we can make np.ones * ID of the point and then np.max(axis=-1)
labels, nb = ndi.label(cp.asnumpy(mask_total))
# image * mask_cluster.astype(np.uint8)
if _algo == "GPU":
import cupy as cp

#outputsomehow is 3D, we want 2
# pos = cp.dstack((round(j.y),round(j.x)))[0].astype(int)
# if j.z:
if 'z' in j:
# "Need to loop each t and do one at a time"
pos = cp.dstack((j.z,j.y,j.x))[0].astype(int)
print("3D",j)
else:
pos = cp.dstack((j.y,j.x))[0].astype(int)
print("2D",j)


##this is what tp.masks.mask_image does maybe put a cupy here to make if faster.
ndim = len(shape)
# radius = validate_tuple(radius, ndim)
pos = cp.atleast_2d(pos)

# if include_edge:
in_mask = cp.array([cp.sum(((cp.indices(shape).T - p) / radius)**2, -1) <= 1
for p in pos])
# else:
# in_mask = [np.sum(((np.indices(shape).T - p) / radius)**2, -1) < 1
# for p in pos]
mask_total = cp.any(in_mask, axis=0).T

##if they overlap the labels won't match the points
#we can make np.ones * ID of the point and then np.max(axis=-1)
labels, nb = ndi.label(cp.asnumpy(mask_total))
# image * mask_cluster.astype(np.uint8)

#this is super slow
# ~ masks = tp.masks.mask_image(coords,np.ones(image.shape),size/2)
elif _algo=='CPU':
if 'z' in j:
# "Need to loop each t and do one at a time"
pos = np.dstack((j.z,j.y,j.x))[0].astype(int)
print("3D",j)
else:
pos = np.dstack((j.y,j.x))[0].astype(int)
print("2D",j)


##this is what tp.masks.mask_image does maybe put a cupy here to make if faster.
ndim = len(shape)
# radius = validate_tuple(radius, ndim)
pos = np.atleast_2d(pos)

# if include_edge:
in_mask = np.array([np.sum(((np.indices(shape).T - p) / radius)**2, -1) <= 1
for p in pos])
# else:
# in_mask = [np.sum(((np.indices(shape).T - p) / radius)**2, -1) < 1
# for p in pos]
mask_total = np.any(in_mask, axis=0).T

##if they overlap the labels won't match the points
#we can make np.ones * ID of the point and then np.max(axis=-1)
labels, nb = ndi.label(mask_total)
elif _algo=='fast':
#This is faster
# r = (size-1)/2 # Radius of circles
# #make 3D compat
# disk_mask = tp.masks.binary_mask(r,image.ndim)
# # Initialize output array and set the maskcenters as 1s
# out = np.zeros(image.shape,dtype=bool)
# #check if there's a problem with subpixel masking
# out[coords[:,0],coords[:,1]] = 1
# # Use binary dilation to get the desired output
# out = binary_dilation(out,disk_mask)

# labels, nb = ndi.label(out)
# if _round:
# return labels, coords
# else:
# if image.ndim == 2:
# # coords = j.loc[:,['particle','frame','y','x']]
# coords = j.loc[:,['frame','y','x']]
# # coords = np.dstack((j.particle,j.y,j.x))[0]
# return labels, coords
r = (size-1)/2 # Radius of circles
# #make 3D compat
disk_mask = tp.masks.binary_mask(r,image.ndim)
# # Initialize output array and set the maskcenters as 1s
out = np.zeros(image.shape,dtype=bool)
# #check if there's a problem with subpixel masking
out[coords[:,0],coords[:,1]] = 1
# # Use binary dilation to get the desired output
out = binary_dilation(out,disk_mask)

labels, nb = ndi.label(out)
# if _round:
# return labels, coords
# else:
# if image.ndim == 2:
# # coords = j.loc[:,['particle','frame','y','x']]
# coords = j.loc[:,['frame','y','x']]
# # coords = np.dstack((j.particle,j.y,j.x))[0]
# return labels, coords
return labels, pos

class IdentifyQWidget(QWidget):
Expand Down Expand Up @@ -299,13 +336,18 @@ def _select_layer(self,i):


def make_masks(self):
import pandas as pd
import pandas as pd\
# if self.viewer.layers[index_layer].scale[0] != 1:
index_layer = _get_choice_layer(self,self.layersbox)
if len(self.viewer.layers[0].data.shape) <= 3:

if self.viewer.layers.selection.active.data.shape[1] <= 3:
##fix here to distinguish between ZYX TYX
df = pd.DataFrame(self.viewer.layers.selection.active.data, columns = ['frame','y','x'])
elif len(self.viewer.layers[0].data.shape) > 3:
elif self.viewer.layers.selection.active.data.shape[1] > 3:
# if self.viewer.layers[index_layer].scale[0] != 1
df = pd.DataFrame(self.viewer.layers.selection.active.data, columns = ['frame','z','y','x'])
# else:
# df = pd.DataFrame(self.viewer.layers.selection.active.data, columns = ['frame','y','x'])
b = self.viewer.layers.selection.active.properties
for key in b.keys():
df[key] = b[key]
Expand All @@ -322,7 +364,7 @@ def make_masks(self):
mask_temp, idx_temp = make_labels_trackpy_links(
self.viewer.layers[index_layer].data[i].shape,
# self.viewer.layers[index_layer].data[i],
temp,radius=(self.diameter_input.value()-1)/2)
temp,radius=(self.diameter_input.value())/2)
# temp,size=5-1)
# print(mask_temp.max(),len(temp.index))
# idx.append(idx_temp)
Expand All @@ -348,10 +390,10 @@ def _on_click(self):
print("Detected more than 3 dimensions")
if self.choice.isChecked():
print("Detected a Time lapse TYX or TZYX image")
a = self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()]
self.f = tp.batch(a,self.diameter_input.value(),minmass=self.mass_slider.value(),
img = np.asarray(self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()])
self.f = tp.batch(img,self.diameter_input.value(),minmass=self.mass_slider.value(),
engine="numba",
processes=1,
# processes=1,
)
#TODO
#if min is not 0 we have to adjust F to bump it up
Expand All @@ -361,7 +403,8 @@ def _on_click(self):
#however if there's a 1um Z stack it will bug
if self.viewer.layers[index_layer].scale[0] != 1:
print("Detected a ZYX image")
self.f = tp.locate(self.viewer.layers[index_layer].data,self.diameter_input.value(),minmass=self.mass_slider.value())
img = np.asarray(self.viewer.layers[index_layer].data)
self.f = tp.locate(img,self.diameter_input.value(),minmass=self.mass_slider.value())
self.f['frame'] = 0
else:
print("Detected a Time lapse ZYX image")
Expand All @@ -372,7 +415,8 @@ def _on_click(self):
self.f['frame'] = _time_locator
elif len(self.viewer.layers[index_layer].data.shape) == 2:
print("Detected only YX")
self.f = tp.locate(self.viewer.layers[index_layer].data,self.diameter_input.value(),minmass=self.mass_slider.value())
img = np.asarray(self.viewer.layers[index_layer].data)
self.f = tp.locate(img,self.diameter_input.value(),minmass=self.mass_slider.value())
self.f['frame'] = 0
#TODO

Expand Down Expand Up @@ -423,12 +467,18 @@ def _on_click2(self):
]


if len(self.viewer.layers[index_layer].data.shape) <= 3:
if len(self.viewer.layers[index_layer].data.shape) < 3:
_points = f2.loc[:,['frame','y','x']]
elif len(self.viewer.layers[index_layer].data.shape) > 3:
_points = f2.loc[:,['frame','z','y','x']]
elif len(self.viewer.layers[index_layer].data.shape) >= 3:
if self.viewer.layers[index_layer].scale[0] != 1:
_points = f2.loc[:,['frame','z','y','x']]
else:
_points = f2.loc[:,['frame','y','x']]
try:
_metadata = f2.loc[:,['mass','size','ecc']]
except:
_metadata = f2.loc[:,['mass','size']]

_metadata = f2.loc[:,['mass','size','ecc']]
self.f2 = f2
self._points_layer_filter = self.viewer.add_points(_points,properties=_metadata,**self.points_options2)
self._points_layer_filter.scale = self.viewer.layers[index_layer].scale
Expand All @@ -437,11 +487,13 @@ def _save_results(self):
import pandas as pd
##TODO
##pull from points layer see example below
if len(self.viewer.layers[index_layer].data.shape) <= 3:
selected_layer = self.viewer.layers.selection
if len(self.viewer.layers[selected_layer].data.shape) <= 3:
#manbearpig time lapse vs Zstack
df = pd.DataFrame(self.viewer.layers.selection.active.data, columns = ['frame','y','x'])
elif len(self.viewer.layers[index_layer].data.shape) > 3:
elif len(self.viewer.layers[selected_layer].data.shape) > 3:
df = pd.DataFrame(self.viewer.layers.selection.active.data, columns = ['frame','z','y','x'])

b = self.viewer.layers.selection.active.properties
for key in b.keys():
df[key] = b[key]
Expand Down

0 comments on commit bdedcdf

Please sign in to comment.