
Add ultrasound confidence map to transforms #6709

Merged · 28 commits · Jul 23, 2023

Commits
11a1399
Add ultrasound confidence map to transforms
MrGranddy Jul 12, 2023
cc0c3ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
14b7af0
Add scipy solver, move the main functionality to data
MrGranddy Jul 19, 2023
9f5b9db
Add scipy solver, move the main functionality to data
MrGranddy Jul 19, 2023
b855917
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
c4ab0b7
Change backend parameter into solver_backend
MrGranddy Jul 19, 2023
43a9b96
Merge branch 'dev' of github.com:MrGranddy/MONAI into dev
MrGranddy Jul 19, 2023
c954220
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
667dcd0
DCO Remediation Commit for Vahit Bugra YESILKAYNAK <bugrayesilkaynak@…
MrGranddy Jul 19, 2023
61e53f6
Merge branch 'dev' of github.com:MrGranddy/MONAI into dev
MrGranddy Jul 19, 2023
052ad64
Auto format
MrGranddy Jul 19, 2023
042f558
Break long comment lines
MrGranddy Jul 19, 2023
1beb54e
Remove octave, add scipy dependencies
MrGranddy Jul 19, 2023
9ae4782
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
9af5fcd
Fix scipy dependency
MrGranddy Jul 19, 2023
29ad20f
Fix scipy dependency
MrGranddy Jul 19, 2023
3b1cc9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
518fdda
Fix unittests
MrGranddy Jul 19, 2023
8067935
Merge branch 'dev' of github.com:MrGranddy/MONAI into dev
MrGranddy Jul 19, 2023
4540069
Add unittest to min_tests.py excluded tests (requires scipy)
MrGranddy Jul 19, 2023
f702071
Make everything flake8 proper
MrGranddy Jul 19, 2023
32edd44
Black formatting
MrGranddy Jul 19, 2023
3bb279e
MyPy formatting
MrGranddy Jul 19, 2023
2c9dd0c
Merge branch 'dev' into dev
wyli Jul 19, 2023
bc4d121
Change input shape into required format, add the unit tests
MrGranddy Jul 23, 2023
0da54e6
Merge branch 'dev' of github.com:MrGranddy/MONAI into dev
MrGranddy Jul 23, 2023
02757a7
Fix Flake8 function name convention
MrGranddy Jul 23, 2023
f388e3d
Merge branch 'dev' into dev
wyli Jul 23, 2023
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -6,6 +6,7 @@ itk>=5.2
nibabel
parameterized
scikit-image>=0.19.0
+scipy>=1.7.1
tensorboard
commonmark==0.9.1
recommonmark==0.6.0
8 changes: 4 additions & 4 deletions docs/source/installation.md
@@ -10,8 +10,8 @@
- [Uninstall the packages](#uninstall-the-packages)
- [From conda-forge](#from-conda-forge)
- [From GitHub](#from-github)
-- [Option 1 (as a part of your system-wide module)](#option-1-as-a-part-of-your-system-wide-module)
-- [Option 2 (editable installation)](#option-2-editable-installation)
+- [Option 1 (as a part of your system-wide module):](#option-1-as-a-part-of-your-system-wide-module)
+- [Option 2 (editable installation):](#option-2-editable-installation)
- [Validating the install](#validating-the-install)
- [MONAI version string](#monai-version-string)
- [From DockerHub](#from-dockerhub)
@@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
+[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
```

-which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
+which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 1 addition & 0 deletions monai/config/deviceconfig.py
@@ -71,6 +71,7 @@ def get_optional_config_values():
output["ITK"] = get_package_version("itk")
output["Nibabel"] = get_package_version("nibabel")
output["scikit-image"] = get_package_version("skimage")
output["scipy"] = get_package_version("scipy")
output["Pillow"] = get_package_version("PIL")
output["Tensorboard"] = get_package_version("tensorboard")
output["gdown"] = get_package_version("gdown")
2 changes: 2 additions & 0 deletions monai/data/__init__.py
@@ -150,3 +150,5 @@ def reduce_meta_tensor(meta_tensor):
return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)

ForkingPickler.register(MetaTensor, reduce_meta_tensor)

+from .ultrasound_confidence_map import UltrasoundConfidenceMap
352 changes: 352 additions & 0 deletions monai/data/ultrasound_confidence_map.py
@@ -0,0 +1,352 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray

from monai.utils import min_version, optional_import

__all__ = ["UltrasoundConfidenceMap"]

cv2, _ = optional_import("cv2")
csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix")
spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve")
hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")


class UltrasoundConfidenceMap:
"""Compute confidence map from an ultrasound image.
This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.
It generates a confidence map by setting source and sink points in the image and computing the probability
for random walks to reach the source for each pixel.

Args:
alpha (float, optional): Alpha parameter. Defaults to 2.0.
beta (float, optional): Beta parameter. Defaults to 90.0.
gamma (float, optional): Gamma parameter. Defaults to 0.05.
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
sink_mode (str, optional): Sink mode, one of 'all', 'mid', 'min', or 'mask'. Defaults to 'all'.
If 'mask' is selected, a sink mask must be provided when calling the transform.
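
Example (a minimal usage sketch; the input here is a hypothetical random B-mode frame)::

import numpy as np
from monai.data import UltrasoundConfidenceMap

cm = UltrasoundConfidenceMap(alpha=2.0, beta=90.0, gamma=0.05, mode="B", sink_mode="all")
img = np.random.rand(256, 128)  # [H x W] image, one scanline per column
conf = cm(img)  # confidence map in [0, 1] with the same [H x W] shape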
"""

def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all"):
# The hyperparameters for confidence map estimation
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.mode = mode
self.sink_mode = sink_mode

# The precision to use for all computations
self.eps = np.finfo("float64").eps

# Store sink indices for external use
self._sink_indices = np.array([], dtype="float64")

def sub2ind(self, size: tuple[int, ...], rows: NDArray, cols: NDArray) -> NDArray:
"""Converts row and column subscripts into linear indices,
basically the copy of the MATLAB function of the same name.
https://www.mathworks.com/help/matlab/ref/sub2ind.html

This function is Pythonic so the indices start at 0.

Args:
size Tuple[int]: Size of the matrix
rows (NDArray): Row indices
cols (NDArray): Column indices

Returns:
indices (NDArray): 1-D array of linear indices
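
For example, for a (3, 2) matrix, sub2ind((3, 2), rows=[2], cols=[1]) gives
2 + 1 * 3 = 5, i.e. column-major (Fortran-order) linear indexing.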
"""
indices: NDArray = rows + cols * size[0]
return indices

def get_seed_and_labels(
self, data: NDArray, sink_mode: str = "all", sink_mask: NDArray | None = None
) -> tuple[NDArray, NDArray]:
"""Get the seed and label arrays for the max-flow algorithm

Args:
data: Input array
sink_mode (str, optional): Sink mode. Defaults to 'all'.
sink_mask (NDArray, optional): Sink mask. Defaults to None.

Returns:
Tuple[NDArray, NDArray]: Seed and label arrays
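
For an (H, W) image with sink_mode='all', for example, this yields W source seeds on the
top row (label 1) and W sink seeds on the bottom row (label 2).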
"""

# Seeds and labels (boundary conditions)
seeds = np.array([], dtype="float64")
labels = np.array([], dtype="float64")

# Indices for all columns
sc = np.arange(data.shape[1], dtype="float64")

# SOURCE ELEMENTS - 1st matrix row
# Indices for the 1st row; it will be broadcast against sc
sr_up = np.array([0])
seed = self.sub2ind(data.shape, sr_up, sc).astype("float64")
seed = np.unique(seed)
seeds = np.concatenate((seeds, seed))

# Label 1
label = np.ones_like(seed)
labels = np.concatenate((labels, label))

# Create seeds for sink elements

if sink_mode == "all":
# All elements in the last row
sr_down = np.ones_like(sc) * (data.shape[0] - 1)
self._sink_indices = np.array([sr_down, sc], dtype="int32")
seed = self.sub2ind(data.shape, sr_down, sc).astype("float64")

elif sink_mode == "mid":
# Middle element in the last row
sc_down = np.array([data.shape[1] // 2])
sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)
self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

elif sink_mode == "min":
# Minimum element in the last row (excluding 10% from the edges)
ten_percent = int(data.shape[1] * 0.1)
min_val = np.min(data[-1, ten_percent:-ten_percent])
min_idxs = np.where(data[-1, ten_percent:-ten_percent] == min_val)[0] + ten_percent
sc_down = min_idxs
sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)
self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

elif sink_mode == "mask":
# All elements in the mask
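# sink_mask is assumed to be provided by the caller when sink_mode == 'mask'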
coords = np.where(sink_mask != 0)
sr_down = coords[0]
sc_down = coords[1]
self._sink_indices = np.array([sr_down, sc_down], dtype="int32")
seed = self.sub2ind(data.shape, sr_down, sc_down).astype("float64")

seed = np.unique(seed)
seeds = np.concatenate((seeds, seed))

# Label 2
label = np.ones_like(seed) * 2
labels = np.concatenate((labels, label))

return seeds, labels

def normalize(self, inp: NDArray) -> NDArray:
"""Normalize an array to [0, 1]"""
normalized_array: NDArray = (inp - np.min(inp)) / (np.ptp(inp) + self.eps)
return normalized_array

def attenuation_weighting(self, img: NDArray, alpha: float) -> NDArray:
"""Compute attenuation weighting

Args:
img (NDArray): Image
alpha: Attenuation coefficient (see publication)

Returns:
w (NDArray): Weighting expressing depth-dependent attenuation
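
For normalized depth d in [0, 1] this is a Beer-Lambert-style weighting,
w(d) = 1 - exp(-alpha * d), so deeper rows receive larger weights.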
"""

# Create depth vector and repeat it for each column
dw = np.linspace(0, 1, img.shape[0], dtype="float64")
dw = np.tile(dw.reshape(-1, 1), (1, img.shape[1]))

w: NDArray = 1.0 - np.exp(-alpha * dw) # Compute exp inline

return w

def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, beta: float, gamma: float):
"""Compute 6-Connected Laplacian for confidence estimation problem

Args:
padded_index (NDArray): The index matrix of the image with boundary padding.
padded_image (NDArray): The padded image.
beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function.
gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.

Returns:
L (csc_matrix): The 8-connected Laplacian matrix used for confidence map estimation.
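
With the padded image flattened column by column (m rows per column), an offset of ±1
reaches the vertical neighbors, ±m the horizontal ones, and m±1, -m±1 the diagonals.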
"""

m, _ = padded_index.shape

padded_index = padded_index.T.flatten()
padded_image = padded_image.T.flatten()

p = np.where(padded_index > 0)[0]

i = padded_index[p] - 1 # Index vector
j = padded_index[p] - 1 # Index vector
# Entries vector, initially for diagonal
s = np.zeros_like(p, dtype="float64")

edge_templates = [
-1, # Vertical edges
1,
m - 1, # Diagonal edges
m + 1,
-m - 1,
-m + 1,
m, # Horizontal edges
-m,
]

vertical_end = None

for iter_idx, k in enumerate(edge_templates):
neigh_idxs = padded_index[p + k]

q = np.where(neigh_idxs > 0)[0]

ii = padded_index[p[q]] - 1
i = np.concatenate((i, ii))
jj = neigh_idxs[q] - 1
j = np.concatenate((j, jj))
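# Since padded_index[p] enumerates 1..n in order, p[ii] == p[q] and p[jj] == p[q] + k,
# so the lookups below compare each node's intensity with that of its k-offset neighbor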
w = np.abs(padded_image[p[ii]] - padded_image[p[jj]]) # Intensity derived weight
s = np.concatenate((s, w))

if iter_idx == 1:
vertical_end = s.shape[0]  # Vertical edges length
# (iter_idx == 5 would mark the end of the diagonal edges; the original code
# computed this length but never used it, see the sqrt(2) note below)

# Normalize weights
s = self.normalize(s)

# Horizontal penalty
s[:vertical_end] += gamma
# s[vertical_end:diagonal_end] += gamma * np.sqrt(2)  # --> The paper uses gamma * sqrt(2)
# for the diagonal edges, since they are longer, but this does not exist in the original code

# Normalize differences
s = self.normalize(s)

# Gaussian weighting function
s = -(
(np.exp(-beta * s, dtype="float64")) + 1.0e-6
)  # --> This epsilon changes the results drastically; default: 1e-6

# Create Laplacian, diagonal missing
lap = csc_matrix((s, (i, j)))

# Reset diagonal weights to zero for summing
# up the weighted edge degree in the next step
lap.setdiag(0)

# Weighted edge degree
diag = np.abs(lap.sum(axis=0).A)[0]

# Finalize Laplacian by completing the diagonal
lap.setdiag(diag)

return lap

def _solve_linear_system(self, lap, rhs):
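# Solve the sparse linear system lap @ x = rhs with scipy's sparse direct solver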
x = spsolve(lap, rhs)

return x

def confidence_estimation(self, img, seeds, labels, beta, gamma):
"""Compute confidence map

Args:
img (NDArray): Processed image.
seeds (NDArray): Seeds for the random walks framework. These are indices of the source and sink nodes.
labels (NDArray): Labels for the random walks framework. These represent the classes or groups of the seeds.
beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function.
gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.

Returns:
map: Confidence map which shows the probability of each pixel belonging to the source or sink group.
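
In the random walker formulation, the unknown probabilities x solve the sparse system
L_U x = -B^T M, where L_U is the Laplacian restricted to the unmarked nodes and M encodes
the source seeds.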
"""

# Index matrix with boundary padding
idx = np.arange(1, img.shape[0] * img.shape[1] + 1).reshape(img.shape[1], img.shape[0]).T
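# (column-major numbering, mirroring the original MATLAB implementation)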
pad = 1

padded_idx = np.pad(idx, (pad, pad), "constant", constant_values=(0, 0))
padded_img = np.pad(img, (pad, pad), "constant", constant_values=(0, 0))

# Laplacian
lap = self.confidence_laplacian(padded_idx, padded_img, beta, gamma)

# Select marked columns from Laplacian to create L_M and B^T
b = lap[:, seeds]

# Select marked nodes to create B^T
n = np.sum(padded_idx > 0).item()
i_u = np.setdiff1d(np.arange(n), seeds.astype(int)) # Index of unmarked nodes
b = b[i_u, :]

# Remove marked nodes from Laplacian by deleting rows and cols
keep_indices = np.setdiff1d(np.arange(lap.shape[0]), seeds)
lap = csc_matrix(lap[keep_indices, :][:, keep_indices])

# Define M matrix
m = np.zeros((seeds.shape[0], 1), dtype="float64")
m[:, 0] = labels == 1

# Right-hand side (-B^T * M)
rhs = -b @ m # type: ignore

# Solve linear system
x = self._solve_linear_system(lap, rhs)

# Prepare output
probabilities = np.zeros((n,), dtype="float64")
# Probabilities for unmarked nodes
probabilities[i_u] = x
# Max probability for marked node
probabilities[seeds[labels == 1].astype(int)] = 1.0

# Final reshape with same size as input image (no padding)
probabilities = probabilities.reshape((img.shape[1], img.shape[0])).T

return probabilities

def __call__(self, data: NDArray, sink_mask: NDArray | None = None) -> NDArray:
"""Compute the confidence map

Args:
data (NDArray): RF or B-mode ultrasound data (one scanline per column), an [H x W] 2D array
sink_mask (NDArray, optional): Sink mask, required when sink_mode is 'mask'. Defaults to None.

Returns:
map (NDArray): Confidence map [H x W] 2D array
"""

# Normalize data
data = data.astype("float64")
data = self.normalize(data)

if self.mode == "RF":
# MATLAB hilbert applies the Hilbert transform to columns
data = np.abs(hilbert(data, axis=0)).astype("float64") # type: ignore

seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask)

# Attenuation with Beer-Lambert
w = self.attenuation_weighting(data, self.alpha)

# Apply weighting directly to image
# Same as applying it individually during the formation of the
# Laplacian
data = data * w

# Find confidence values
map_: NDArray = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma)

return map_