Skip to content

Commit

Permalink
Update _widget.py
Browse files Browse the repository at this point in the history
  • Loading branch information
joaomamede authored Feb 24, 2024
1 parent 4b3a138 commit 0120c27
Showing 1 changed file with 144 additions and 45 deletions.
189 changes: 144 additions & 45 deletions src/napari_trackpy/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _get_open_filename(self,type='image',separator= " :: ",choice_widget=None):
_filename = os.path.splitext(self.viewer.layers[j]._source.path)[0]
return _filename

def make_labels_trackpy_links(shape,j,radius=5,_round=False,_algo="GPU"):
def make_labels_trackpy_links(shape,j,radius=5,_algo="GPU"):
import trackpy as tp
import scipy.ndimage as ndi
from scipy.ndimage import binary_dilation
Expand All @@ -83,10 +83,10 @@ def make_labels_trackpy_links(shape,j,radius=5,_round=False,_algo="GPU"):
# if j.z:
if 'z' in j:
# "Need to loop each t and do one at a time"
pos = cp.dstack((j.z,j.y,j.x))[0].astype(int)
pos = cp.dstack((j.z,j.y,j.x))[0]#.astype(int)
print("3D",j)
else:
pos = cp.dstack((j.y,j.x))[0].astype(int)
pos = cp.dstack((j.y,j.x))[0]#.astype(int)
print("2D",j)


Expand All @@ -113,42 +113,54 @@ def make_labels_trackpy_links(shape,j,radius=5,_round=False,_algo="GPU"):
elif _algo=='CPU':
if 'z' in j:
# "Need to loop each t and do one at a time"
pos = np.dstack((j.z,j.y,j.x))[0].astype(int)
pos = np.dstack((j.z,j.y,j.x))[0]#.astype(int)
print("3D",j)
else:
pos = np.dstack((j.y,j.x))[0].astype(int)
pos = np.dstack((j.y,j.x))[0]#.astype(int)
print("2D",j)


##this is what tp.masks.mask_image does maybe put a cupy here to make if faster.
ndim = len(shape)
# radius = validate_tuple(radius, ndim)
pos = np.atleast_2d(pos)

# if include_edge:
in_mask = np.array([np.sum(((np.indices(shape).T - p) / radius)**2, -1) <= 1
for p in pos])
# else:
# in_mask = [np.sum(((np.indices(shape).T - p) / radius)**2, -1) < 1
# for p in pos]
mask_total = np.any(in_mask, axis=0).T

##if they overlap the labels won't match the points
#we can make np.ones * ID of the point and then np.max(axis=-1)
labels, nb = ndi.label(mask_total)
elif _algo=='fast':
#This is faster
r = (size-1)/2 # Radius of circles

# r = (radius-1)/2 # Radius of circles
# print(radius,r)
# #make 3D compat
disk_mask = tp.masks.binary_mask(r,image.ndim)
disk_mask = tp.masks.binary_mask(radius,len(shape))
# print(disk_mask)
# # Initialize output array and set the maskcenters as 1s
out = np.zeros(image.shape,dtype=bool)
# #check if there's a problem with subpixel masking
out[coords[:,0],coords[:,1]] = 1
out = np.zeros(shape,dtype=bool)

if 'z' in j:
pos = np.dstack((j.z,j.y,j.x))[0].astype(int)
pos = np.atleast_2d(pos)
print(pos)
out[pos[:,0],pos[:,1],pos[:,2]] = 1

else:
pos = np.dstack((j.y,j.x))[0].astype(int)
pos = np.atleast_2d(pos)
print(pos)
out[pos[:,0],pos[:,1]] = 1
# # Use binary dilation to get the desired output

out = binary_dilation(out,disk_mask)

labels, nb = ndi.label(out)
print("Number of labels:",nb)
# if _round:
# return labels, coords
# else:
Expand All @@ -157,6 +169,55 @@ def make_labels_trackpy_links(shape,j,radius=5,_round=False,_algo="GPU"):
# coords = j.loc[:,['frame','y','x']]
# # coords = np.dstack((j.particle,j.y,j.x))[0]
# return labels, coords
# elif _algo == 'openCL':
# def opencl_function(ctx, queue, shape, pos, radius, include_edge=True):
# ndim = len(shape)

# # Flatten the shape to a single value for the buffer
# shape_flat = np.prod(shape)

# # Set up OpenCL buffers
# shape_buf = cl.Buffer(ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=shape)
# pos_buf = cl.Buffer(ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=pos)
# radius_buf = cl.Buffer(ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=radius)

# # Create an OpenCL buffer for the result
# result = np.empty((pos.shape[0],) + shape, dtype=np.bool)
# result_buf = cl.Buffer(ctx, cl.mem_flags.WRITE_ONLY, result.nbytes)

# # Write OpenCL kernel
# kernel_code = """
# __kernel void mask_kernel(const int ndim,
# const int shape_flat,
# __global const int* shape,
# __global const float* pos,
# __global const float* radius,
# const int include_edge,
# __global int* result) {
# int gid = get_global_id(0);
# int pid = get_global_id(1);

# float distance_squared = 0.0;

# for (int d = 0; d < ndim; ++d) {
# float diff = (get_global_size(1) > 1) ? (pos[pid * ndim + d] - gid % shape[d]) : 0.0;
# distance_squared += diff * diff;
# }

# result[pid * shape_flat + gid] = ((distance_squared / (radius[pid] * radius[pid])) <= 1.0) ? 1 : 0;
# }
# """

# # Build the program
# program = cl.Program(ctx, kernel_code).build()

# # Execute kernel
# program.mask_kernel(queue, result.shape, None, np.int32(ndim), np.int32(shape_flat), shape_buf, pos_buf, radius_buf, np.int32(include_edge), result_buf)

# # Retrieve results
# cl.enqueue_copy(queue, result, result_buf).wait()

# return result
return labels, pos

class IdentifyQWidget(QWidget):
Expand Down Expand Up @@ -186,7 +247,7 @@ def __init__(self, napari_viewer):
self.mass_slider = QSpinBox()
self.mass_slider.setRange(0, int(1e6))
self.mass_slider.setSingleStep(200)
self.mass_slider.setValue(4000)
self.mass_slider.setValue(25000)
l2 = QLabel()

l2.setText("Diameter of the particle")
Expand All @@ -203,7 +264,7 @@ def __init__(self, napari_viewer):
self.size_filter_input = QDoubleSpinBox()
self.size_filter_input.setRange(0, 10)
self.size_filter_input.setSingleStep(0.05)
self.size_filter_input.setValue(1.40)
self.size_filter_input.setValue(1.60)
self.layoutH0.addWidget(self.size_filter_tick)
self.layoutH0.addWidget(self.size_filter_input)

Expand Down Expand Up @@ -255,7 +316,7 @@ def __init__(self, napari_viewer):
self.layout_masks.addWidget(label_masks)
self.layout_masks.addWidget(self.make_masks_box)
self.layout_masks.addWidget(self.masks_option)

self.masks_dict = {0:'GPU',1:'CPU',2:'fast'}

btn = QPushButton("Identify Spots")
btn.clicked.connect(self._on_click)
Expand Down Expand Up @@ -418,15 +479,19 @@ def make_masks(self):
i = int(i)
temp = df[df['frame'] == i].sort_values(by=['y'])
#0 returns mask, 1 index returns coords
self.masks_dict[self.masks_option.currentIndex()]
print("Doing Masks with option:",self.masks_option.currentIndex(), self.masks_dict[self.masks_option.currentIndex()])
mask_temp, idx_temp = make_labels_trackpy_links(
self.viewer.layers[index_layer].data[i].shape,
# self.viewer.layers[index_layer].data[i],
temp,radius=((self.diameter_input.value())/2)-0.5)
temp,radius=((self.diameter_input.value())/2)-0.5,
# _round=False,
_algo=self.masks_dict[self.masks_option.currentIndex()],
)
# temp,size=5-1)
# print(mask_temp.max(),len(temp.index))
# idx.append(idx_temp)
mask_fixed = np.copy(mask_temp)

##this is needed when doing from links because all labels set as 'particles'
# for j in np.unique(mask_temp).astype('int'):
# if j != 0:
Expand Down Expand Up @@ -457,8 +522,8 @@ def _on_click(self):
print("Detected more than 3 dimensions")
if self.choice.isChecked():
print("Detected a Time lapse TYX or TZYX image")
img = np.asarray(self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()])
# img = self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()]
# img = np.asarray(self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()])
img = self.viewer.layers[index_layer].data[self.min_timer.value():self.max_timer.value()]
self.f = tp.batch(img,self.diameter_input.value(),minmass=self.mass_slider.value(),
engine="numba",
processes=1,
Expand Down Expand Up @@ -503,7 +568,9 @@ def _on_click(self):
#transforming data to pandas ready for spots
if len(self.viewer.layers[index_layer].data.shape) <= 3:
#XYZ
if self.viewer.layers[index_layer].scale[0] != 1:
if len(self.viewer.layers[index_layer].data.shape) == 2:
_points = self.f.loc[:,['frame','y','x']]
elif self.viewer.layers[index_layer].scale[0] != 1:
_points = self.f.loc[:,['frame','z','y','x']]
#TYX PRJ
else:
Expand All @@ -518,25 +585,36 @@ def _on_click(self):
#like this is opposite color of the image
#make if smarter self.viewer.layers[index_layer].colormap.color has an array with the colors, we should be able to flip universaly
clr_name = self.viewer.layers[index_layer].colormap.name
if clr_name == 'green':
point_colors = 'magenta'
elif clr_name == 'red':
point_colors = 'cyan'
elif clr_name == 'blue':
point_colors = 'yellow'
clr_dict = {'green':'magenta',
'red':'cyan',
'blue':'yellow',
'magenta':'green',
'cyan':'red',
'yellow':'blue'}
# if clr_name == 'green':
# point_colors = 'magenta'
# elif clr_name == 'red':
# point_colors = 'cyan'
# elif clr_name == 'blue':
# point_colors = 'yellow'
point_colors = clr_dict[clr_name]
if len(_points) > 0:
self._points_layer = self.viewer.add_points(_points,name="Points "+name_points,properties=_metadata,**self.points_options,edge_color=point_colors)
self._points_layer.scale = self.viewer.layers[index_layer].scale

self.btn2.setEnabled(True)

#auto_save depends on the last created spots layer, if it's done after make_masks, it segfaults
#Keep the same order or redo the code
if self.auto_save.isChecked():
self._save_results()

if self.make_masks_box.isChecked():
_masks = self.make_masks()
self._masks_layer = self.viewer.add_labels(_masks)
self._masks_layer.scale = self.viewer.layers[index_layer].scale

self._points_layer = self.viewer.add_points(_points,name="Points "+name_points,properties=_metadata,**self.points_options,edge_color=point_colors)
self._points_layer.scale = self.viewer.layers[index_layer].scale

self.btn2.setEnabled(True)

if self.make_masks_box.isChecked():
_masks = self.make_masks()
self._masks_layer = self.viewer.add_labels(_masks)
self._masks_layer.scale = self.viewer.layers[index_layer].scale

if self.auto_save.isChecked():
self._save_results()

def batch_on_click(self):
print(self.batch_grid_layout.count())
Expand Down Expand Up @@ -664,6 +742,7 @@ def __init__(self, napari_viewer):
self.viewer.layers.events.removed.connect(self._enable_tracking)
self.viewer.layers.events.inserted.connect(self._enable_tracking)
self.viewer.layers.events.reordered.connect(self._enable_tracking)
self.viewer.layers.selection.events.changed.connect(self._enable_tracking)
#Selecting
# self.viewer.layers.events.selecting.connect(self._enable_tracking)

Expand Down Expand Up @@ -706,7 +785,15 @@ def _enable_tracking(self):
#if self.viewer.layers.selection.active.data == more than one time:
self.btn.setEnabled(True)


# def _get_choice_layer(self,_widget):
# for j,layer in enumerate(self.viewer.layers):
# if layer.name == _widget.currentText():
# index_layer = j
# break
# print("Layer where points are is:",j)
# return index_layer


def _track(self):
import pandas as pd
##if 2d
Expand All @@ -725,12 +812,16 @@ def _track(self):
links = tp.link(df, search_range=self.distance.value(),memory=self.memory.value())
if self.stubs_tick.isChecked():
links = tp.filter_stubs(links, self.stubs_input.value())
#if 2D:
_tracks = links.loc[:,['particle','frame','y','x']]
#if 3d:
# if 2D:
if 'z' in df:
_tracks = links.loc[:,['particle','frame','z','y','x']]
else:
_tracks = links.loc[:,['particle','frame','y','x']]
# if 3d:
# _tracks = links.loc[:,['particle','frame','z','y','x']]

self.viewer.add_tracks(_tracks,name='trackpy')
_tracks = self.viewer.add_tracks(_tracks,name='trackpy')
_tracks.scale = self.viewer.layers[0].scale[1:]
self.links = links
print(links)

Expand Down Expand Up @@ -808,6 +899,7 @@ def __init__(self, napari_viewer):
self.layout().addWidget(save_btn)

import pyqtgraph as pg

self._plt = pg.plot()
self.layout().addWidget(self._plt)

Expand Down Expand Up @@ -852,7 +944,7 @@ def open_file_dialog(self):
def calculate_colocalizing(self):
#makecode from notebooks
import scipy

import pyqtgraph as pg
from sklearn.neighbors import KDTree
print("Doing Colocalization")
QuestionPOS = self._get_points(self.points_question)
Expand Down Expand Up @@ -881,7 +973,10 @@ def calculate_colocalizing(self):
line1 = self._plt.plot(
_bins[:-1],_hist,
# distances_list,
pen='g',
# pen=pg.mkPen(pg.intColor(np.random.randint(16), 16)),
pen=(np.random.randint(16), 16),
# pen='r'
# pen=self.points_question.colormap.name
# symbol='x', symbolPen='g',
# symbolBrush=0.2,
name='green'
Expand Down Expand Up @@ -914,6 +1009,10 @@ def calculate_colocalizing(self):
self)+'_'+coloc_name+"_Spots.csv")
self._save_results()

# return {coloc_name:[AnchorPOS,_colocalizing])}
# return spots count Filename what channel is anchor which is question and numbers?
#save directly to a file that is all.csv or file_coloc_counts.csv?

def calculate_all_colocalizing(self):
idx_anchor = self.points_anchor.currentIndex()
_num_items = self.points_question.count()
Expand Down

0 comments on commit 0120c27

Please sign in to comment.