isolate deep learning code
Henley13 committed May 5, 2020
1 parent 3c5aa70 commit 11f3c57
Showing 7 changed files with 281 additions and 5 deletions.
5 changes: 1 addition & 4 deletions bigfish/classification/__init__.py
@@ -8,7 +8,6 @@
 from .input_preparation import (prepare_coordinate_data,
                                 build_boundaries_layers, build_surface_layers,
                                 build_distance_layers, Generator)
-from .squeezenet import SqueezeNet0, SqueezeNet_qbi
 from .features import get_features, get_features_name
 
 # ### Load models ###
@@ -19,6 +18,4 @@
                        "build_surface_layers", "build_distance_layers",
                        "Generator"]
 
-_squeezenet = ["SqueezeNet0", "SqueezeNet_qbi"]
-
-__all__ = _features + _input_preparation + _squeezenet
+__all__ = _features + _input_preparation
Empty file.
12 changes: 12 additions & 0 deletions bigfish/deep_learning/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

"""
The bigfish.deep_learning module includes deep learning models and routines.
"""

from .squeezenet import SqueezeNet0, SqueezeNet_qbi


_squeezenet = ["SqueezeNet0", "SqueezeNet_qbi"]

__all__ = _squeezenet
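
With this commit the SqueezeNet models are exposed under the new bigfish.deep_learning subpackage instead of bigfish.classification. A minimal sketch of the resulting import change for downstream code (assuming the package is installed):

# before this commit:
# from bigfish.classification import SqueezeNet0, SqueezeNet_qbi
# after this commit:
from bigfish.deep_learning import SqueezeNet0, SqueezeNet_qbi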
File renamed without changes.
@@ -20,7 +20,7 @@
 import tensorflow as tf
 import numpy as np
 
-from .base import BaseModel, get_optimizer
+from bigfish.deep_learning.base import BaseModel, get_optimizer
 
 from tensorflow.python.keras.backend import function, learning_phase
 from tensorflow.python.keras.models import Model
File renamed without changes.
267 changes: 267 additions & 0 deletions bigfish/deep_learning/utils.py
@@ -0,0 +1,267 @@
from optparse import OptionParser
import os
from os.path import join

import numpy as np
from skimage.measure import label
from skimage.morphology import erosion, disk
from skimage.io import imsave
# jaccard_similarity_score was deprecated and later removed from scikit-learn;
# jaccard_score is its replacement for computing the Jaccard index.
from sklearn.metrics import (accuracy_score, roc_auc_score, jaccard_score,
                             f1_score, recall_score, precision_score,
                             confusion_matrix)
from progressbar import ProgressBar

from postproc.postprocessing import PostProcess


def GetOptions():
    """
    Defines the command-line options needed by the script.
    """
    # optparse expects its type as a string name such as "int" or "string".
    parser = OptionParser()
    parser.add_option("--tf_record", dest="TFRecord", type="string",
                      default="",
                      help="where to find the TFRecord file")
    parser.add_option("--path", dest="path", type="string",
                      help="where to collect the patches")
    parser.add_option("--size_train", dest="size_train", type="int",
                      help="size of the input image to the network")
    parser.add_option("--log", dest="log",
                      help="log directory")
    parser.add_option("--learning_rate", dest="lr", type="float",
                      default=0.01,
                      help="learning rate")
    parser.add_option("--batch_size", dest="bs", type="int", default=1,
                      help="batch size")
    parser.add_option("--epoch", dest="epoch", type="int", default=1,
                      help="number of epochs")
    parser.add_option("--n_features", dest="n_features", type="int",
                      help="number of channels on the first layers")
    parser.add_option("--weight_decay", dest="weight_decay", type="float",
                      default=0.00005,
                      help="weight decay value")
    parser.add_option("--dropout", dest="dropout", type="float", default=0.5,
                      help="dropout rate to apply to the FC layers")
    parser.add_option("--mean_file", dest="mean_file", type="string",
                      help="where to find the mean file to subtract from the "
                           "original image")
    parser.add_option("--n_threads", dest="THREADS", type="int", default=100,
                      help="number of threads to use for the preprocessing")
    parser.add_option("--crop", dest="crop", type="int", default=4,
                      help="crop size depending on validation/test/train "
                           "phase")
    parser.add_option("--split", dest="split", type="string",
                      help="validation/test/train phase")
    parser.add_option("--p1", dest="p1", type="int",
                      help="1st input for post-processing")
    parser.add_option("--p2", dest="p2", type="float",
                      help="2nd input for post-processing")
    parser.add_option("--iters", dest="iters", type="int")
    parser.add_option("--seed", dest="seed", type="int")
    parser.add_option("--size_test", dest="size_test", type="int")
    parser.add_option("--restore", dest="restore", type="string")
    parser.add_option("--save_path", dest="save_path", type="string",
                      default=".")
    parser.add_option("--type", dest="type", type="string",
                      help="type for the datagen")
    parser.add_option("--UNet", dest="UNet", action="store_true")
    parser.add_option("--no-UNet", dest="UNet", action="store_false")
    parser.add_option("--output", dest="output", type="string")
    parser.add_option("--output_csv", dest="output_csv", type="string")

    (options, args) = parser.parse_args()

    return options
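
A quick sketch of how GetOptions might be driven; the script name and flag values below are purely illustrative:

import sys

# Simulate a typical training command line (hypothetical values).
sys.argv = ["train.py", "--path", "./patches", "--size_train", "224",
            "--log", "./log", "--batch_size", "4", "--epoch", "10",
            "--split", "train", "--UNet"]
options = GetOptions()
print(options.path, options.bs, options.epoch, options.UNet)
# -> ./patches 4 10 True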


def ComputeMetrics(prob, batch_labels, p1, p2, rgb=None, save_path=None,
                   ind=0):
    """
    Computes all the metrics between a probability map and the corresponding
    label image. If an RGB image is also given, several extra diagnostic
    images are saved to save_path.
    """
    GT = label(batch_labels.copy())
    PRED = PostProcess(prob, p1, p2)
    # PRED = label((prob > 0.5).astype('uint8'))
    lbl = GT.copy()
    pred = PRED.copy()
    aji = AJI_fast(lbl, pred)
    lbl[lbl > 0] = 1
    pred[pred > 0] = 1
    l, p = lbl.flatten(), pred.flatten()
    acc = accuracy_score(l, p)
    roc = roc_auc_score(l, p)
    jac = jaccard_score(l, p)
    f1 = f1_score(l, p)
    recall = recall_score(l, p)
    precision = precision_score(l, p)
    if rgb is not None:
        xval_n = join(save_path, "xval_{}.png".format(ind))
        yval_n = join(save_path, "yval_{}.png".format(ind))
        prob_n = join(save_path, "prob_{}.png".format(ind))
        pred_n = join(save_path, "pred_{}.png".format(ind))
        c_gt_n = join(save_path, "C_gt_{}.png".format(ind))
        c_pr_n = join(save_path, "C_pr_{}.png".format(ind))

        imsave(xval_n, rgb)
        imsave(yval_n, color_bin(GT))
        imsave(prob_n, prob)
        imsave(pred_n, color_bin(PRED))
        imsave(c_gt_n, add_contours(rgb, GT))
        imsave(c_pr_n, add_contours(rgb, PRED))

    return acc, roc, jac, recall, precision, f1, aji
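
A hypothetical call on random data, assuming the external postproc.postprocessing module is importable and that p1 and p2 are its post-processing parameters (all values below are illustrative):

import numpy as np

prob = np.random.rand(64, 64)                     # fake probability map
gt = (np.random.rand(64, 64) > 0.7).astype(int)   # fake ground truth
acc, roc, jac, recall, precision, f1, aji = ComputeMetrics(
    prob, gt, p1=10, p2=0.5)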


def color_bin(bin_labl):
    """
    Colours a labelled image, giving each object a random colour so that
    the nuclei stand out.
    """
    dim = bin_labl.shape
    x, y = dim[0], dim[1]
    res = np.zeros(shape=(x, y, 3))
    for i in range(1, bin_labl.max() + 1):
        rgb = np.random.normal(loc=125, scale=100, size=3)
        rgb[rgb < 0] = 0
        rgb[rgb > 255] = 255
        rgb = rgb.astype(np.uint8)
        res[bin_labl == i] = rgb
    return res.astype(np.uint8)


def add_contours(rgb_image, contour, ds=2):
    """
    Draws the contours of a label image in black on top of an RGB image.
    Note that the contour argument is binarised in place.
    """
    rgb = rgb_image.copy()
    contour[contour > 0] = 1
    boundary = contour - erosion(contour, disk(ds))
    rgb[boundary > 0] = np.array([0, 0, 0])
    return rgb
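
A small sketch of these two helpers on a synthetic label image (sizes and labels are illustrative):

import numpy as np

labels = np.zeros((32, 32), dtype=int)
labels[4:12, 4:12] = 1       # first object
labels[18:28, 18:28] = 2     # second object
rgb = np.full((32, 32, 3), 255, dtype=np.uint8)       # white background

colored = color_bin(labels)                           # one random colour per object
outlined = add_contours(rgb, labels.copy(), ds=1)     # black outlines on the RGB image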


def CheckOrCreate(path):
    """
    Creates the directory if it does not already exist; otherwise does
    nothing.
    """
    if not os.path.isdir(path):
        os.makedirs(path)


def Intersection(A, B):
    """
    Returns the binary intersection mask of A and B (its sum gives the
    pixel count of the intersection).
    """
    C = A + B
    C[C != 2] = 0
    C[C == 2] = 1
    return C


def Union(A, B):
    """
    Returns the binary union mask of A and B (its sum gives the pixel
    count of the union).
    """
    C = A + B
    C[C > 0] = 1
    return C
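
A toy example of the two mask helpers (note that they return masks, whose sums give the pixel counts):

import numpy as np

A = np.array([[1, 1, 0],
              [0, 1, 0],
              [0, 0, 0]])
B = np.array([[0, 1, 1],
              [0, 1, 0],
              [0, 0, 0]])
print(Intersection(A, B).sum())   # 2 pixels in common
print(Union(A, B).sum())          # 4 pixels in total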


def AssociatedCell(G_i, S):
    """
    Returns the index of the prediction cell associated with a given
    ground-truth element, i.e. the cell of S maximising the Jaccard index
    with G_i. (A special case could be added for ground-truth elements
    that touch no prediction cell.)
    """

    def g(indice):
        S_indice = np.zeros_like(S)
        S_indice[S == indice] = 1
        NUM = float(Intersection(G_i, S_indice).sum())
        DEN = float(Union(G_i, S_indice).sum())
        return NUM / DEN

    # list() is needed in Python 3, where map returns an iterator.
    res = list(map(g, range(1, S.max() + 1)))
    indice = np.array(res).argmax() + 1
    return indice


pbar = ProgressBar()


def AJI(G, S):
    """
    AJI (Aggregated Jaccard Index) as described in the paper. AJI_fast
    below is a more abstract implementation, but roughly 100 times faster.
    """
    G = label(G, background=0)
    S = label(S, background=0)

    C = 0
    U = 0
    USED = np.zeros(S.max())

    for i in pbar(range(1, G.max() + 1)):
        only_ground_truth = np.zeros_like(G)
        only_ground_truth[G == i] = 1
        j = AssociatedCell(only_ground_truth, S)
        only_prediction = np.zeros_like(S)
        only_prediction[S == j] = 1
        C += Intersection(only_prediction, only_ground_truth).sum()
        U += Union(only_prediction, only_ground_truth).sum()
        USED[j - 1] = 1

    def h(indice):
        if USED[indice - 1] == 1:
            return 0
        else:
            only_prediction = np.zeros_like(S)
            only_prediction[S == indice] = 1
            return only_prediction.sum()

    # list() is needed in Python 3, where map returns an iterator.
    U_sum = list(map(h, range(1, S.max() + 1)))
    U += np.sum(U_sum)
    return float(C) / float(U)


def AJI_fast(G, S):
    """
    AJI as described in the paper, but a much faster implementation.
    """
    G = label(G, background=0)
    S = label(S, background=0)
    if S.sum() == 0:
        return 0.
    C = 0
    U = 0
    USED = np.zeros(S.max())

    G_flat = G.flatten()
    S_flat = S.flatten()
    G_max = np.max(G_flat)
    S_max = np.max(S_flat)
    m_labels = max(G_max, S_max) + 1
    # np.float was removed from recent NumPy releases; use the builtin.
    cm = confusion_matrix(G_flat, S_flat,
                          labels=list(range(m_labels))).astype(float)
    LIGNE_J = np.zeros(S_max)
    for j in range(1, S_max + 1):
        LIGNE_J[j - 1] = cm[:, j].sum()

    for i in range(1, G_max + 1):
        LIGNE_I_sum = cm[i, :].sum()

        def h(indice):
            LIGNE_J_sum = LIGNE_J[indice - 1]
            inter = cm[i, indice]
            union = LIGNE_I_sum + LIGNE_J_sum - inter
            return inter / union

        # list() is needed in Python 3, where map returns an iterator.
        JI_ligne = list(map(h, range(1, S_max + 1)))
        best_indice = np.argmax(JI_ligne) + 1
        C += cm[i, best_indice]
        U += LIGNE_J[best_indice - 1] + LIGNE_I_sum - cm[i, best_indice]
        USED[best_indice - 1] = 1

    U_sum = ((1 - USED) * LIGNE_J).sum()
    U += U_sum
    return float(C) / float(U)
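
A quick sanity check on synthetic masks: identical masks should give an AJI of 1, while an imperfect prediction scores lower.

import numpy as np

gt = np.zeros((8, 8), dtype=int)
gt[1:4, 1:4] = 1        # first nucleus
gt[5:7, 5:7] = 2        # second nucleus
print(AJI_fast(gt, gt.copy()))    # 1.0

pred = np.zeros_like(gt)
pred[2:5, 2:5] = 1                # partial overlap with the first nucleus only
print(AJI_fast(gt, pred))         # < 1.0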
