-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
281 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
The bigfish.deep_learning module includes deep learning models and routines. | ||
""" | ||
|
||
from .squeezenet import SqueezeNet0, SqueezeNet_qbi | ||
|
||
|
||
_squeezenet = ["SqueezeNet0", "SqueezeNet_qbi"] | ||
|
||
__all__ = _squeezenet |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
from optparse import OptionParser | ||
from skimage.measure import label | ||
from sklearn.metrics import accuracy_score, roc_auc_score | ||
from sklearn.metrics import jaccard_similarity_score, f1_score | ||
from sklearn.metrics import recall_score, precision_score, confusion_matrix | ||
from skimage.morphology import erosion, disk | ||
from os.path import join | ||
import os | ||
from skimage.io import imsave, imread | ||
import numpy as np | ||
import pdb | ||
import time | ||
from progressbar import ProgressBar | ||
from postproc.postprocessing import PostProcess, generate_wsl | ||
|
||
|
||
def GetOptions(): | ||
""" | ||
Defines most of the options needed | ||
""" | ||
parser = OptionParser() | ||
parser.add_option("--tf_record", dest="TFRecord", type="string", | ||
default="", | ||
help="Where to find the TFrecord file") | ||
parser.add_option("--path", dest="path", type="string", | ||
help="Where to collect the patches") | ||
parser.add_option("--size_train", dest="size_train", type="int", | ||
help="size of the input image to the network") | ||
parser.add_option("--log", dest="log", | ||
help="log dir") | ||
parser.add_option("--learning_rate", dest="lr", type="float", default=0.01, | ||
help="learning_rate") | ||
parser.add_option("--batch_size", dest="bs", type="int", default=1, | ||
help="batch size") | ||
parser.add_option("--epoch", dest="epoch", type="int", default=1, | ||
help="number of epochs") | ||
parser.add_option("--n_features", dest="n_features", type="int", | ||
help="number of channels on first layers") | ||
parser.add_option("--weight_decay", dest="weight_decay", type="float", | ||
default=0.00005, | ||
help="weight decay value") | ||
parser.add_option("--dropout", dest="dropout", type="float", | ||
default=0.5, | ||
help="dropout value to apply to the FC layers.") | ||
parser.add_option("--mean_file", dest="mean_file", type="str", | ||
help="where to find the mean file to substract to the original image.") | ||
parser.add_option('--n_threads', dest="THREADS", type=int, default=100, | ||
help="number of threads to use for the preprocessing.") | ||
parser.add_option('--crop', dest="crop", type=int, default=4, | ||
help="crop size depending on validation/test/train phase.") | ||
parser.add_option('--split', dest="split", type="str", | ||
help="validation/test/train phase.") | ||
parser.add_option('--p1', dest="p1", type="int", | ||
help="1st input for post processing.") | ||
parser.add_option('--p2', dest="p2", type="float", | ||
help="2nd input for post processing.") | ||
parser.add_option('--iters', dest="iters", type="int") | ||
parser.add_option('--seed', dest="seed", type="int") | ||
parser.add_option('--size_test', dest="size_test", type="int") | ||
parser.add_option('--restore', dest="restore", type="str") | ||
parser.add_option('--save_path', dest="save_path", type="str", default=".") | ||
parser.add_option('--type', dest="type", type="str", | ||
help="Type for the datagen") | ||
parser.add_option('--UNet', dest='UNet', action='store_true') | ||
parser.add_option('--no-UNet', dest='UNet', action='store_false') | ||
parser.add_option('--output', dest="output", type="str") | ||
parser.add_option('--output_csv', dest="output_csv", type="str") | ||
|
||
(options, args) = parser.parse_args() | ||
|
||
return options | ||
|
||
|
||
def ComputeMetrics(prob, batch_labels, p1, p2, rgb=None, save_path=None, | ||
ind=0): | ||
""" | ||
Computes all metrics between probability map and corresponding label. | ||
If you give also an rgb image it will save many extra meta data image. | ||
""" | ||
GT = label(batch_labels.copy()) | ||
PRED = PostProcess(prob, p1, p2) | ||
# PRED = label((prob > 0.5).astype('uint8')) | ||
lbl = GT.copy() | ||
pred = PRED.copy() | ||
aji = AJI_fast(lbl, pred) | ||
lbl[lbl > 0] = 1 | ||
pred[pred > 0] = 1 | ||
l, p = lbl.flatten(), pred.flatten() | ||
acc = accuracy_score(l, p) | ||
roc = roc_auc_score(l, p) | ||
jac = jaccard_similarity_score(l, p) | ||
f1 = f1_score(l, p) | ||
recall = recall_score(l, p) | ||
precision = precision_score(l, p) | ||
if rgb is not None: | ||
xval_n = join(save_path, "xval_{}.png").format(ind) | ||
yval_n = join(save_path, "yval_{}.png").format(ind) | ||
prob_n = join(save_path, "prob_{}.png").format(ind) | ||
pred_n = join(save_path, "pred_{}.png").format(ind) | ||
c_gt_n = join(save_path, "C_gt_{}.png").format(ind) | ||
c_pr_n = join(save_path, "C_pr_{}.png").format(ind) | ||
|
||
imsave(xval_n, rgb) | ||
imsave(yval_n, color_bin(GT)) | ||
imsave(prob_n, prob) | ||
imsave(pred_n, color_bin(PRED)) | ||
imsave(c_gt_n, add_contours(rgb, GT)) | ||
imsave(c_pr_n, add_contours(rgb, PRED)) | ||
|
||
return acc, roc, jac, recall, precision, f1, aji | ||
|
||
|
||
def color_bin(bin_labl): | ||
""" | ||
Colors bin image so that nuclei come out nicer. | ||
""" | ||
dim = bin_labl.shape | ||
x, y = dim[0], dim[1] | ||
res = np.zeros(shape=(x, y, 3)) | ||
for i in range(1, bin_labl.max() + 1): | ||
rgb = np.random.normal(loc=125, scale=100, size=3) | ||
rgb[rgb < 0] = 0 | ||
rgb[rgb > 255] = 255 | ||
rgb = rgb.astype(np.uint8) | ||
res[bin_labl == i] = rgb | ||
return res.astype(np.uint8) | ||
|
||
|
||
def add_contours(rgb_image, contour, ds=2): | ||
""" | ||
Adds contours to images. | ||
The image has to be a binary image | ||
""" | ||
rgb = rgb_image.copy() | ||
contour[contour > 0] = 1 | ||
boundery = contour - erosion(contour, disk(ds)) | ||
rgb[boundery > 0] = np.array([0, 0, 0]) | ||
return rgb | ||
|
||
|
||
def CheckOrCreate(path): | ||
""" | ||
If path exists, does nothing otherwise it creates it. | ||
""" | ||
if not os.path.isdir(path): | ||
os.makedirs(path) | ||
|
||
|
||
def Intersection(A, B): | ||
""" | ||
Returns the pixel count corresponding to the intersection | ||
between A and B. | ||
""" | ||
C = A + B | ||
C[C != 2] = 0 | ||
C[C == 2] = 1 | ||
return C | ||
|
||
|
||
def Union(A, B): | ||
""" | ||
Returns the pixel count corresponding to the union | ||
between A and B. | ||
""" | ||
C = A + B | ||
C[C > 0] = 1 | ||
return C | ||
|
||
|
||
def AssociatedCell(G_i, S): | ||
""" | ||
Returns the indice of the associated prediction cell for a certain | ||
ground truth element. Maybe do something if no associated cell in the | ||
prediction mask touches the GT. | ||
""" | ||
|
||
def g(indice): | ||
S_indice = np.zeros_like(S) | ||
S_indice[S == indice] = 1 | ||
NUM = float(Intersection(G_i, S_indice).sum()) | ||
DEN = float(Union(G_i, S_indice).sum()) | ||
return NUM / DEN | ||
|
||
res = map(g, range(1, S.max() + 1)) | ||
indice = np.array(res).argmax() + 1 | ||
return indice | ||
|
||
|
||
pbar = ProgressBar() | ||
|
||
|
||
def AJI(G, S): | ||
""" | ||
AJI as described in the paper, AJI is more abstract implementation but 100times faster. | ||
""" | ||
G = label(G, background=0) | ||
S = label(S, background=0) | ||
|
||
C = 0 | ||
U = 0 | ||
USED = np.zeros(S.max()) | ||
|
||
for i in pbar(range(1, G.max() + 1)): | ||
only_ground_truth = np.zeros_like(G) | ||
only_ground_truth[G == i] = 1 | ||
j = AssociatedCell(only_ground_truth, S) | ||
only_prediction = np.zeros_like(S) | ||
only_prediction[S == j] = 1 | ||
C += Intersection(only_prediction, only_ground_truth).sum() | ||
U += Union(only_prediction, only_ground_truth).sum() | ||
USED[j - 1] = 1 | ||
|
||
def h(indice): | ||
if USED[indice - 1] == 1: | ||
return 0 | ||
else: | ||
only_prediction = np.zeros_like(S) | ||
only_prediction[S == indice] = 1 | ||
return only_prediction.sum() | ||
|
||
U_sum = map(h, range(1, S.max() + 1)) | ||
U += np.sum(U_sum) | ||
return float(C) / float(U) | ||
|
||
|
||
def AJI_fast(G, S): | ||
""" | ||
AJI as described in the paper, but a much faster implementation. | ||
""" | ||
G = label(G, background=0) | ||
S = label(S, background=0) | ||
if S.sum() == 0: | ||
return 0. | ||
C = 0 | ||
U = 0 | ||
USED = np.zeros(S.max()) | ||
|
||
G_flat = G.flatten() | ||
S_flat = S.flatten() | ||
G_max = np.max(G_flat) | ||
S_max = np.max(S_flat) | ||
m_labels = max(G_max, S_max) + 1 | ||
cm = confusion_matrix(G_flat, S_flat, labels=range(m_labels)).astype( | ||
np.float) | ||
LIGNE_J = np.zeros(S_max) | ||
for j in range(1, S_max + 1): | ||
LIGNE_J[j - 1] = cm[:, j].sum() | ||
|
||
for i in range(1, G_max + 1): | ||
LIGNE_I_sum = cm[i, :].sum() | ||
|
||
def h(indice): | ||
LIGNE_J_sum = LIGNE_J[indice - 1] | ||
inter = cm[i, indice] | ||
|
||
union = LIGNE_I_sum + LIGNE_J_sum - inter | ||
return inter / union | ||
|
||
JI_ligne = map(h, range(1, S_max + 1)) | ||
best_indice = np.argmax(JI_ligne) + 1 | ||
C += cm[i, best_indice] | ||
U += LIGNE_J[best_indice - 1] + LIGNE_I_sum - cm[i, best_indice] | ||
USED[best_indice - 1] = 1 | ||
|
||
U_sum = ((1 - USED) * LIGNE_J).sum() | ||
U += U_sum | ||
return float(C) / float(U) |