From 2ed005313d37caad87409ccfddb3918726325841 Mon Sep 17 00:00:00 2001 From: Shannon Axelrod Date: Tue, 8 Oct 2019 15:28:29 -0700 Subject: [PATCH] graph based decoding --- .travis.yml | 3 + notebooks/graph_decoding.ipynb | 634 ++++++++++++++++++ notebooks/py/graph_decoding.py | 446 ++++++++++++ .../per_round_max_channel_decoder.py | 11 +- .../test/test_graph_trace_builder.py | 368 ++++++++++ .../core/spots/DecodeSpots/trace_builders.py | 96 ++- starfish/core/spots/DecodeSpots/util.py | 449 ++++++++++++- starfish/core/spots/FindSpots/__init__.py | 1 + starfish/core/spots/FindSpots/h_max.py | 145 ++++ .../FindSpots/test/test_spot_detection.py | 11 + starfish/core/types/_constants.py | 1 + 11 files changed, 2158 insertions(+), 7 deletions(-) create mode 100644 notebooks/graph_decoding.ipynb create mode 100644 notebooks/py/graph_decoding.py create mode 100644 starfish/core/spots/DecodeSpots/test/test_graph_trace_builder.py create mode 100644 starfish/core/spots/FindSpots/h_max.py diff --git a/.travis.yml b/.travis.yml index e508941c6..ce7f87f60 100644 --- a/.travis.yml +++ b/.travis.yml @@ -87,6 +87,9 @@ jobs: - name: SeqFISH Notebook if: type = push and (branch = master or branch =~ /-alltest/) script: make install-dev notebooks/py/SeqFISH.py + - name: Graph-based Decoding Notebook + if: type = push and (branch = master or branch =~ /-alltest/) + script: make install-dev notebooks/py/graph_decoding.py # - name: STARmap Notebook # if: type = push and (branch = master or branch =~ /-alltest/) # script: make install-dev notebooks/py/STARmap.py diff --git a/notebooks/graph_decoding.ipynb b/notebooks/graph_decoding.ipynb new file mode 100644 index 000000000..0ac844bea --- /dev/null +++ b/notebooks/graph_decoding.ipynb @@ -0,0 +1,634 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Graph-based Decoding\n", + "\n", + "This notebook walks through how to use a Graph-based decoding approach [1] to process spatial transcriptomic data. Graph-based decoding can only be applied to assays with one-hot-encoding codebooks (i.e. a single fluorescent channel active per round). We will first see how the graph-based decoding module works with some toy examples, and after we apply it on a real application with in situ sequencing data.\n", + "\n", + "The graph based decoding module ```LocalGraphBlobDetector``` builds a graph out of the candidate spots detected by an arbitrary spot finder algorithm (please see [documentation](https://spacetx-starfish.readthedocs.io/en/stable/) for a list of spot finder algorithms included in the module). Nodes of the graph representing detected spots are then connected with edges based on spatial distances. Cost weights proportional to the distance and quality of the detected spots are then assigned to each edge connecting a pair of nodes. Genes are finally decoded by optimizing the graph with respect to the edge costs providing the best spots configuration with higher qualities and smaller distances.\n", + "\n", + "In details, ```LocalGraphBlobDetector``` first finds spots for every channel and round. Four possible spot detectors are integrated from [scikit-image](https://scikit-image.org/), two based local maxima and two blob detection algorithms. Secondly, overlapping spots are merged across channels within each round in order to handle fluorescent bleed-trough. Next, a quality score is assigned for each detected spot (maximum channel intensity divided by channel intensity vector l2-norm). Detected spots belonging to different sequencing rounds and closer than `search_radius` are connected in a graph, forming connected components of spot detections. Next, for each connected component, edges between not connected spots belonging to consecutive rounds are forced if they are closer than `search_radius_max`. Finally, all the edges that connect spots not belonging to consecutive rounds are removed and each connected component is solved by maximum flow minimum cost algorithm. Where, costs are proportional to spot quality and distances.\n", + "\n", + "[1] Partel, G. et al. Identification of spatial compartments in tissue from in situ sequencing data. BioRxiv, https://doi.org/10.1101/765842, (2019)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import pprint\n", + "from scipy.ndimage.filters import gaussian_filter\n", + "from starfish.core.spots.FindSpots import BlobDetector, HMax, LocalMaxPeakFinder\n", + "from starfish.core.spots.DecodeSpots import PerRoundMaxChannel\n", + "from starfish.core.spots.DecodeSpots.trace_builders import build_traces_graph_based\n", + "from starfish.types import TraceBuildingStrategies\n", + "\n", + "\n", + "from starfish import data, FieldOfView, ImageStack\n", + "from starfish.types import Axes, Features, FunctionSource\n", + "from starfish.util.plot import imshow_plane" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 1\n", + "We first see an example on how to tune two important parameters, `search_radius` and `search_radius_max`, that define the graph connections of detected spots. We start by creating three synthetic spots laying in two channels and three sequential rounds (color coded with red, green and blue colors). Each of the spot has 3px shift in x,y respect to the spot in the previous round." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create synthetic data\n", + "img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32)\n", + "\n", + "# code 1\n", + "# round 1\n", + "img[0, 0, 0, 35, 35] = 10\n", + "# round 2\n", + "img[1, 1, 0, 32, 32] = 10\n", + "# round 3\n", + "img[2, 0, 0, 29, 29] = 10\n", + "\n", + "# blur points\n", + "gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img)\n", + "stack = ImageStack.from_numpy(img)\n", + "\n", + "plt.imshow(np.moveaxis(np.amax(np.squeeze(img),axis=1),0,-1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now decode the sequence with `LocalGraphBlobDetector` setting `search_radius` to an approximate value representing the euclidean distance between two spots belonging to different sequencing rounds, and `search_radius_max` to a value representing the maximum euclidean distance between all the spots of the same sequence." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# search_radius=5, search_radius_max=10\n", + "hmax_detector = HMax(h=0.5)\n", + "spots = hmax_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " search_radius=5,\n", + " search_radius_max=10,\n", + " anchor_round=0,\n", + " k_d=0.33\n", + ")\n", + "# One sequence decoded\n", + "intensity_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# search_radius=5, search_radius_max=5\n", + "spots = hmax_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " search_radius=5,\n", + " search_radius_max=5,\n", + " anchor_round=0,\n", + " k_d=0.33\n", + ")\n", + "# Zero sequence decoded\n", + "intensity_table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2\n", + "We now change the distances between the spots such that the 3rd round spot (blue) lay between the other two, and compare the results with other decoding approaches." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create synthetic data\n", + "img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32)\n", + "\n", + "# code 1\n", + "# round 1\n", + "img[0, 0, 0, 35, 35] = 10\n", + "# round 2\n", + "img[1, 1, 0, 29, 29] = 10\n", + "# round 3\n", + "img[2, 0, 0, 32, 32] = 10\n", + "\n", + "# blur points\n", + "gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img)\n", + "stack = ImageStack.from_numpy(img)\n", + "\n", + "plt.imshow(np.moveaxis(np.amax(np.squeeze(img),axis=1),0,-1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LocalGraphBlobDetector\n", + "blob_detector = BlobDetector(\n", + " detector_method='blob_log',\n", + " min_sigma=(0.4, 1.2, 1.2),\n", + " max_sigma=(0.6, 1.7, 1.7),\n", + " num_sigma=3,\n", + " threshold=0.1,\n", + " overlap=0.5\n", + ")\n", + "spots = blob_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " search_radius=5,\n", + " search_radius_max=10,\n", + " anchor_round=0,\n", + " k_d=0.33\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`LocalSearchBlobDetector` decode the correct sequence only if `anchor_round` is set to the round of the spot lying in the middle, otherwise will fail to connect all the spots." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LocalSearchBlobDetector\n", + "# Anchor round: 1st round\n", + "spots = blob_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " anchor_round=1,\n", + " search_radius=5,\n", + " search_radius_max=10,\n", + " k_d=0.33\n", + ")\n", + "intensity_table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LocalSearchBlobDetector\n", + "# Anchor round: 3rd round\n", + "spots = blob_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " search_radius=5,\n", + " search_radius_max=10,\n", + " anchor_round=2,\n", + " k_d=0.33\n", + ")\n", + "intensity_table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3\n", + "Let's now add some noise and multiple possible decoding choices. Specifically, the second round has two possible spot candidates one a bit weaker than the other, both equally spaced respect from the spots of the other rounds. `LocalGraphBlobDetector` provides the best possible decoding solution choosing the spot with highest quality for the second round since distance costs are equivalent. The results of `LocalSearchBlobDetector` strongly depends from the initialization of the anchor round (i.e. from where to start the search), providing different solutions for each initialization.\n", + "(Please note that when anchor round is set to the second round, two sequences are decoded, the correct one plus a false positive.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create synthetic data\n", + "img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32)\n", + "\n", + "# code 1\n", + "# round 1\n", + "img[0, 0, 0, 35, 32] = 1\n", + "# round 2\n", + "img[1, 1, 0, 36, 34] = 1\n", + "img[1, 0, 0, 32, 28] = 0.5 # Noise\n", + "# round 3\n", + "img[2, 0, 0, 33, 30] = 1\n", + "\n", + "# blur points\n", + "gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img)\n", + "\n", + "# add camera noise\n", + "np.random.seed(6)\n", + "camera_noise = np.random.normal(scale=0.005, size=img.shape).astype(np.float32)\n", + "img = img + np.clip(camera_noise,0.001,None)\n", + "\n", + "stack = ImageStack.from_numpy(img)\n", + "\n", + "plt.imshow(np.moveaxis(np.amax(np.squeeze(img*10),axis=1),0,-1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spots = blob_detector.run(stack)\n", + "intensity_table = build_traces_graph_based(\n", + " spot_results=spots,\n", + " search_radius=5,\n", + " search_radius_max=5,\n", + " anchor_round=0,\n", + " k_d=0.33\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Reproduce In-situ Sequencing results with Starfish Graph-based Decoding\n", + "\n", + "Let's now see the `LocalGraphBlobDetector` applied to In Situ Sequencing (ISS). ISS is an image based transcriptomics technique that can spatially resolve hundreds RNA species and their expression levels in-situ. The protocol and data analysis are described in this [publication](https://www.ncbi.nlm.nih.gov/pubmed/23852452). Here we use the `LocalGraphBlobDetector` to process the raw images from an ISS experiment into a spatially resolved cell by gene expression matrix. And we verify that we can accurately reproduce the results from the authors' original [pipeline](https://cellprofiler.org/previous_examples/#sequencing-rna-molecules-in-situ-combining-cellprofiler-with-imagej-plugins)\n", + "\n", + "Please see [documentation](https://spacetx-starfish.readthedocs.io/en/stable/) for detailed descriptions of all the data structures and methods used here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "matplotlib.rcParams[\"figure.dpi\"] = 150" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data into Starfish from the Cloud\n", + "\n", + "The primary data from one field of view correspond to 16 images from 4 hybridization rounds (r) 4 color channels (c) one z plane (z). Each image is 1044 x 1390 (y, x). These data arise from human breast tissue. O(10) transcripts are barcoded for subsequent spatial resolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# An experiment contains a codebook, primary images, and auxiliary images\n", + "experiment = data.ISS(use_test_data=True)\n", + "pp = pprint.PrettyPrinter(indent=2)\n", + "pp.pprint(experiment._src_doc)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fov = experiment.fov()\n", + "dots = fov.get_image(\"dots\")\n", + "dots_single_plane = dots.reduce((Axes.ROUND, Axes.ZPLANE), func=\"max\", module=FunctionSource.np)\n", + "nuclei = fov.get_image(\"nuclei\")\n", + "nuclei_single_plane = nuclei.reduce((Axes.ROUND, Axes.ZPLANE), func=\"max\", module=FunctionSource.np)\n", + "\n", + "# note the structure of the 5D tensor containing the raw imaging data\n", + "imgs = fov.get_image(FieldOfView.PRIMARY_IMAGES)\n", + "print(imgs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize Codebook\n", + "\n", + "The ISS codebook maps each decoded barcode to a gene.This protocol asserts that genes are encoded with a length 4 quaternary barcode that can be read out from the images. Each round encodes a position in the codeword. The maximum signal in each color channel (columns in the above image) corresponds to a letter in the codeword. The channels, in order, correspond to the letters: 'T', 'G', 'C', 'A'." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment.codebook" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Filter raw data before decoding into spatially resolved gene expression\n", + "\n", + "A White-Tophat filter can be used to enhance spots while minimizing background autoflourescence. The ```masking_radius``` parameter specifies the expected radius, in pixels, of each spot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from starfish.image import Filter\n", + "\n", + "# filter raw data\n", + "masking_radius = 15\n", + "filt = Filter.WhiteTophat(masking_radius, is_volume=False)\n", + "\n", + "filtered_imgs = filt.run(imgs, verbose=True, in_place=False)\n", + "filt.run(dots, verbose=True, in_place=True)\n", + "filt.run(nuclei, verbose=True, in_place=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register data\n", + "Images may have shifted between imaging rounds. This needs to be corrected for before decoding, since this shift in the images will corrupt the barcodes, thus hindering decoding accuracy. A simple procedure can correct for this shift. For each imaging round, the max projection across color channels should look like the dots stain. Below, we simply shift all images in each round to match the dots stain by learning the shift that maximizes the cross-correlation between the images and the dots stain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from starfish.image import ApplyTransform, LearnTransform\n", + "\n", + "learn_translation = LearnTransform.Translation(reference_stack=dots, axes=Axes.ROUND, upsampling=1000)\n", + "transforms_list = learn_translation.run(imgs.reduce({Axes.CH, Axes.ZPLANE}, func=\"max\"))\n", + "warp = ApplyTransform.Warp()\n", + "registered_imgs = warp.run(filtered_imgs, transforms_list=transforms_list, in_place=False, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Decode the processed data into spatially resolved gene expression profiles" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```LocalGraphBlobDetector``` instance using [Laplacian of Gaussian (LoG)](https://scikit-image.org/docs/dev/api/skimage.feature.html?highlight=blob_log#skimage.feature.blob_log) blob detection algorithm. Please refer to `scikit-image` documentation for a full parameter list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "decoded_lgbd = []\n", + "blob_detector = BlobDetector(\n", + " min_sigma=(0, 0.5, 0.5),\n", + " max_sigma=(0, 3, 3),\n", + " num_sigma=10,\n", + " threshold=0.03,\n", + ")\n", + "spots = blob_detector.run(registered_imgs)\n", + "lgbd = PerRoundMaxChannel(\n", + " codebook=experiment.codebook,\n", + " trace_building_strategy=TraceBuildingStrategies.GRAPH,\n", + " anchor_round=0,\n", + " search_radius=3,\n", + " search_radius_max=5\n", + ")\n", + "decoded_intensity_table = lgbd.run(spots=spots)\n", + "decoded_lgbd.append(decoded_intensity_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```LocalGraphBlobDetector``` instance using [`peak_local_max`](https://scikit-image.org/docs/dev/api/skimage.feature.html?highlight=peak_local_max#skimage.feature.peak_local_max) local maxima detection algorithm. Please refer to `scikit-image` documentation for a full parameter list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "local_max_detector = LocalMaxPeakFinder(\n", + " min_distance=6,\n", + " stringency=0,\n", + " threshold=0.03,\n", + " min_obj_area=0,\n", + " max_obj_area=np.inf,\n", + ")\n", + "spots = local_max_detector.run(registered_imgs)\n", + "decoded_intensity_table = lgbd.run(spots=spots)\n", + "decoded_lgbd.append(decoded_intensity_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```LocalGraphBlobDetector``` instance using [`h_maxima`](https://scikit-image.org/docs/dev/api/skimage.morphology.html?highlight=h_maxima#skimage.morphology.h_maxima) local maxima detection algorithm. Please refer to `scikit-image` documentation for a full parameter list." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "connectivity=np.array([[[0, 0, 0],[0, 1, 0],[0, 0, 0]],[[0, 1, 0],[1, 1, 1],[0, 1, 0]],[[0, 0, 0],[0, 1, 0],[0, 0, 0]]]) #3D corss\n", + "\n", + "hmax_detector = HMax(h=0.015, selem=connectivity)\n", + "spots = hmax_detector.run(registered_imgs)\n", + "decoded_intensity_table = lgbd.run(spots=spots)\n", + "decoded_lgbd.append(decoded_intensity_table)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now compare the results from the three decoding approaches used previously with `BlobDetector` algorithm. This spot detection finds spots, and record, for each spot, the maximum pixel intensities across rounds and channels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bd = BlobDetector(\n", + " min_sigma=0.5,\n", + " max_sigma=3,\n", + " num_sigma=10,\n", + " threshold=0.03,\n", + " measurement_type='max',\n", + ")\n", + "# detect spots using laplacian of gaussians approach\n", + "dots_max_projector = Filter.Reduce((Axes.ROUND, Axes.ZPLANE), func=\"max\", module=FunctionSource.np)\n", + "dots_max = dots_max_projector.run(fov.get_image('dots'))\n", + "# locate spots in a reference image\n", + "spots = bd.run(reference_image=dots_max, image_stack=registered_imgs)\n", + "decoder = PerRoundMaxChannel(codebook=experiment.codebook)\n", + "decoded_bd = decoder.run(spots=spots)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To decode the resulting intensity tables, we simply match intensities to codewords in the codebook. This can be done by, for each round, selecting the color channel with the maximum intensity. This forms a potential quaternary code which serves as the key into a lookup in the codebook as to which gene this code corresponds to. Decoded genes are assigned to the target field in the decoded intensity table." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now compare the results of the results from three ```LocalGraphBlobDetector``` instances with respect to `BlobDetector` results, plotting the correlation of decoded read counts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.rc('font', family='serif', size=10)\n", + "for i in range(3):\n", + " decoded_tmp = decoded_lgbd[i]\n", + " genes_lgbd, counts_lgbd = np.unique(decoded_tmp.loc[decoded_tmp[Features.PASSES_THRESHOLDS]][Features.TARGET], return_counts=True)\n", + " result_counts_lgbd = pd.Series(counts_lgbd, index=genes_lgbd)\n", + "\n", + " genes_bd, counts_bd = np.unique(decoded_bd.loc[decoded_bd[Features.PASSES_THRESHOLDS]][Features.TARGET], return_counts=True)\n", + " result_counts_bd = pd.Series(counts_bd, index=genes_bd)\n", + "\n", + " tmp = pd.concat([result_counts_lgbd, result_counts_bd], join='inner', axis=1).values\n", + "\n", + " r = np.corrcoef(tmp[:, 1], tmp[:, 0])[0, 1]\n", + " x = np.linspace(50, 2000)\n", + "\n", + " f = plt.figure()\n", + " ax = plt.subplot(1,2,1)\n", + " ax.scatter(tmp[:, 1], tmp[:, 0], 50, zorder=2)\n", + "\n", + " ax.plot(x, x, '-k', zorder=1)\n", + " plt.xlabel('Gene copy number BlobDetector')\n", + " plt.ylabel('Gene copy number LGBD')\n", + " plt.xscale('log')\n", + " plt.yscale('log')\n", + " plt.title(f'r = {r}');" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare to results from paper\n", + "This FOV was selected to make sure that we can visualize the tumor/stroma boundary, below this is described by pseudo-coloring HER2 (tumor) and vimentin (VIM, stroma). This distribution matches the one described in the original paper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from skimage.color import rgb2gray\n", + "\n", + "decoded_lgbd = decoded_lgbd[2]\n", + "\n", + "GENE1 = 'HER2'\n", + "GENE2 = 'VIM'\n", + "\n", + "rgb = np.zeros(registered_imgs.tile_shape + (3,))\n", + "nuclei_mp = nuclei.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func=\"max\")\n", + "nuclei_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)\n", + "rgb[:,:,0] = nuclei_numpy\n", + "dots_mp = dots.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func=\"max\")\n", + "dots_mp_numpy = dots_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)\n", + "rgb[:,:,1] = dots_mp_numpy\n", + "do = rgb2gray(rgb)\n", + "do = do/(do.max())\n", + "\n", + "plt.imshow(do,cmap='gray')\n", + "plt.axis('off');\n", + "\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore', FutureWarning)\n", + " is_gene1 = decoded_lgbd.where(decoded_lgbd[Features.AXIS][Features.TARGET] == GENE1, drop=True)\n", + " is_gene2 = decoded_lgbd.where(decoded_lgbd[Features.AXIS][Features.TARGET] == GENE2, drop=True)\n", + "\n", + "plt.plot(is_gene1.x, is_gene1.y, 'or', markersize=3)\n", + "plt.plot(is_gene2.x, is_gene2.y, 'ob', markersize=3)\n", + "plt.title(f'Red: {GENE1}, Blue: {GENE2}');" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/notebooks/py/graph_decoding.py b/notebooks/py/graph_decoding.py new file mode 100644 index 000000000..db0734c43 --- /dev/null +++ b/notebooks/py/graph_decoding.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python +# coding: utf-8 +# +# EPY: stripped_notebook: {"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1"}}, "nbformat": 4, "nbformat_minor": 2} + +# EPY: START markdown +## Graph-based Decoding +# +#This notebook walks through how to use a Graph-based decoding approach [1] to process spatial transcriptomic data. Graph-based decoding can only be applied to assays with one-hot-encoding codebooks (i.e. a single fluorescent channel active per round). We will first see how the graph-based decoding module works with some toy examples, and after we apply it on a real application with in situ sequencing data. +# +#The graph based decoding module ```LocalGraphBlobDetector``` builds a graph out of the candidate spots detected by an arbitrary spot finder algorithm (please see [documentation](https://spacetx-starfish.readthedocs.io/en/stable/) for a list of spot finder algorithms included in the module). Nodes of the graph representing detected spots are then connected with edges based on spatial distances. Cost weights proportional to the distance and quality of the detected spots are then assigned to each edge connecting a pair of nodes. Genes are finally decoded by optimizing the graph with respect to the edge costs providing the best spots configuration with higher qualities and smaller distances. +# +#In details, ```LocalGraphBlobDetector``` first finds spots for every channel and round. Four possible spot detectors are integrated from [scikit-image](https://scikit-image.org/), two based local maxima and two blob detection algorithms. Secondly, overlapping spots are merged across channels within each round in order to handle fluorescent bleed-trough. Next, a quality score is assigned for each detected spot (maximum channel intensity divided by channel intensity vector l2-norm). Detected spots belonging to different sequencing rounds and closer than `search_radius` are connected in a graph, forming connected components of spot detections. Next, for each connected component, edges between not connected spots belonging to consecutive rounds are forced if they are closer than `search_radius_max`. Finally, all the edges that connect spots not belonging to consecutive rounds are removed and each connected component is solved by maximum flow minimum cost algorithm. Where, costs are proportional to spot quality and distances. +# +#[1] Partel, G. et al. Identification of spatial compartments in tissue from in situ sequencing data. BioRxiv, https://doi.org/10.1101/765842, (2019). +# EPY: END markdown + +# EPY: START code +# EPY: ESCAPE %matplotlib inline + +import numpy as np +import os +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +import pprint +from scipy.ndimage.filters import gaussian_filter +from starfish.core.spots.FindSpots import BlobDetector, HMax, LocalMaxPeakFinder +from starfish.core.spots.DecodeSpots import PerRoundMaxChannel +from starfish.core.spots.DecodeSpots.trace_builders import build_traces_graph_based +from starfish.types import TraceBuildingStrategies + + +from starfish import data, FieldOfView, ImageStack +from starfish.types import Axes, Features, FunctionSource +from starfish.util.plot import imshow_plane +# EPY: END code + +# EPY: START markdown +### Example 1 +#We first see an example on how to tune two important parameters, `search_radius` and `search_radius_max`, that define the graph connections of detected spots. We start by creating three synthetic spots laying in two channels and three sequential rounds (color coded with red, green and blue colors). Each of the spot has 3px shift in x,y respect to the spot in the previous round. +# EPY: END markdown + +# EPY: START code +# Create synthetic data +img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32) + +# code 1 +# round 1 +img[0, 0, 0, 35, 35] = 10 +# round 2 +img[1, 1, 0, 32, 32] = 10 +# round 3 +img[2, 0, 0, 29, 29] = 10 + +# blur points +gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img) +stack = ImageStack.from_numpy(img) + +plt.imshow(np.moveaxis(np.amax(np.squeeze(img),axis=1),0,-1)) +# EPY: END code + +# EPY: START markdown +#We now decode the sequence with `LocalGraphBlobDetector` setting `search_radius` to an approximate value representing the euclidean distance between two spots belonging to different sequencing rounds, and `search_radius_max` to a value representing the maximum euclidean distance between all the spots of the same sequence. +# EPY: END markdown + +# EPY: START code +# search_radius=5, search_radius_max=10 +hmax_detector = HMax(h=0.5) +spots = hmax_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + search_radius=5, + search_radius_max=10, + anchor_round=0, + k_d=0.33 +) +# One sequence decoded +intensity_table +# EPY: END code + +# EPY: START code +# search_radius=5, search_radius_max=5 +spots = hmax_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + search_radius=5, + search_radius_max=5, + anchor_round=0, + k_d=0.33 +) +# Zero sequence decoded +intensity_table +# EPY: END code + +# EPY: START markdown +### Example 2 +#We now change the distances between the spots such that the 3rd round spot (blue) lay between the other two, and compare the results with other decoding approaches. +# EPY: END markdown + +# EPY: START code +# Create synthetic data +img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32) + +# code 1 +# round 1 +img[0, 0, 0, 35, 35] = 10 +# round 2 +img[1, 1, 0, 29, 29] = 10 +# round 3 +img[2, 0, 0, 32, 32] = 10 + +# blur points +gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img) +stack = ImageStack.from_numpy(img) + +plt.imshow(np.moveaxis(np.amax(np.squeeze(img),axis=1),0,-1)) +# EPY: END code + +# EPY: START code +# LocalGraphBlobDetector +blob_detector = BlobDetector( + detector_method='blob_log', + min_sigma=(0.4, 1.2, 1.2), + max_sigma=(0.6, 1.7, 1.7), + num_sigma=3, + threshold=0.1, + overlap=0.5 +) +spots = blob_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + search_radius=5, + search_radius_max=10, + anchor_round=0, + k_d=0.33 +) +# EPY: END code + +# EPY: START markdown +#`LocalSearchBlobDetector` decode the correct sequence only if `anchor_round` is set to the round of the spot lying in the middle, otherwise will fail to connect all the spots. +# EPY: END markdown + +# EPY: START code +# LocalSearchBlobDetector +# Anchor round: 1st round +spots = blob_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + anchor_round=1, + search_radius=5, + search_radius_max=10, + k_d=0.33 +) +intensity_table +# EPY: END code + +# EPY: START code +# LocalSearchBlobDetector +# Anchor round: 3rd round +spots = blob_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + search_radius=5, + search_radius_max=10, + anchor_round=2, + k_d=0.33 +) +intensity_table +# EPY: END code + +# EPY: START markdown +### Example 3 +#Let's now add some noise and multiple possible decoding choices. Specifically, the second round has two possible spot candidates one a bit weaker than the other, both equally spaced respect from the spots of the other rounds. `LocalGraphBlobDetector` provides the best possible decoding solution choosing the spot with highest quality for the second round since distance costs are equivalent. The results of `LocalSearchBlobDetector` strongly depends from the initialization of the anchor round (i.e. from where to start the search), providing different solutions for each initialization. +#(Please note that when anchor round is set to the second round, two sequences are decoded, the correct one plus a false positive.) +# EPY: END markdown + +# EPY: START code +# Create synthetic data +img = np.zeros((3, 2, 1, 50, 50), dtype=np.float32) + +# code 1 +# round 1 +img[0, 0, 0, 35, 32] = 1 +# round 2 +img[1, 1, 0, 36, 34] = 1 +img[1, 0, 0, 32, 28] = 0.5 # Noise +# round 3 +img[2, 0, 0, 33, 30] = 1 + +# blur points +gaussian_filter(img, (0, 0, 0, 1.5, 1.5), output=img) + +# add camera noise +np.random.seed(6) +camera_noise = np.random.normal(scale=0.005, size=img.shape).astype(np.float32) +img = img + np.clip(camera_noise,0.001,None) + +stack = ImageStack.from_numpy(img) + +plt.imshow(np.moveaxis(np.amax(np.squeeze(img*10),axis=1),0,-1)) +# EPY: END code + +# EPY: START code +spots = blob_detector.run(stack) +intensity_table = build_traces_graph_based( + spot_results=spots, + search_radius=5, + search_radius_max=5, + anchor_round=0, + k_d=0.33 +) +# EPY: END code + +# EPY: START markdown +### Reproduce In-situ Sequencing results with Starfish Graph-based Decoding +# +#Let's now see the `LocalGraphBlobDetector` applied to In Situ Sequencing (ISS). ISS is an image based transcriptomics technique that can spatially resolve hundreds RNA species and their expression levels in-situ. The protocol and data analysis are described in this [publication](https://www.ncbi.nlm.nih.gov/pubmed/23852452). Here we use the `LocalGraphBlobDetector` to process the raw images from an ISS experiment into a spatially resolved cell by gene expression matrix. And we verify that we can accurately reproduce the results from the authors' original [pipeline](https://cellprofiler.org/previous_examples/#sequencing-rna-molecules-in-situ-combining-cellprofiler-with-imagej-plugins) +# +#Please see [documentation](https://spacetx-starfish.readthedocs.io/en/stable/) for detailed descriptions of all the data structures and methods used here. +# EPY: END markdown + +# EPY: START code +matplotlib.rcParams["figure.dpi"] = 150 +# EPY: END code + +# EPY: START markdown +### Load Data into Starfish from the Cloud +# +#The primary data from one field of view correspond to 16 images from 4 hybridization rounds (r) 4 color channels (c) one z plane (z). Each image is 1044 x 1390 (y, x). These data arise from human breast tissue. O(10) transcripts are barcoded for subsequent spatial resolution. +# EPY: END markdown + +# EPY: START code + +# An experiment contains a codebook, primary images, and auxiliary images +experiment = data.ISS(use_test_data=True) +pp = pprint.PrettyPrinter(indent=2) +pp.pprint(experiment._src_doc) +# EPY: END code + +# EPY: START code +fov = experiment.fov() +dots = fov.get_image("dots") +dots_single_plane = dots.reduce((Axes.ROUND, Axes.ZPLANE), func="max", module=FunctionSource.np) +nuclei = fov.get_image("nuclei") +nuclei_single_plane = nuclei.reduce((Axes.ROUND, Axes.ZPLANE), func="max", module=FunctionSource.np) + +# note the structure of the 5D tensor containing the raw imaging data +imgs = fov.get_image(FieldOfView.PRIMARY_IMAGES) +print(imgs) +# EPY: END code + +# EPY: START markdown +### Visualize Codebook +# +#The ISS codebook maps each decoded barcode to a gene.This protocol asserts that genes are encoded with a length 4 quaternary barcode that can be read out from the images. Each round encodes a position in the codeword. The maximum signal in each color channel (columns in the above image) corresponds to a letter in the codeword. The channels, in order, correspond to the letters: 'T', 'G', 'C', 'A'. +# EPY: END markdown + +# EPY: START code +experiment.codebook +# EPY: END code + +# EPY: START markdown +### Filter raw data before decoding into spatially resolved gene expression +# +#A White-Tophat filter can be used to enhance spots while minimizing background autoflourescence. The ```masking_radius``` parameter specifies the expected radius, in pixels, of each spot. +# EPY: END markdown + +# EPY: START code +from starfish.image import Filter + +# filter raw data +masking_radius = 15 +filt = Filter.WhiteTophat(masking_radius, is_volume=False) + +filtered_imgs = filt.run(imgs, verbose=True, in_place=False) +filt.run(dots, verbose=True, in_place=True) +filt.run(nuclei, verbose=True, in_place=True) +# EPY: END code + +# EPY: START markdown +### Register data +#Images may have shifted between imaging rounds. This needs to be corrected for before decoding, since this shift in the images will corrupt the barcodes, thus hindering decoding accuracy. A simple procedure can correct for this shift. For each imaging round, the max projection across color channels should look like the dots stain. Below, we simply shift all images in each round to match the dots stain by learning the shift that maximizes the cross-correlation between the images and the dots stain. +# EPY: END markdown + +# EPY: START code + +from starfish.image import ApplyTransform, LearnTransform + +learn_translation = LearnTransform.Translation(reference_stack=dots, axes=Axes.ROUND, upsampling=1000) +transforms_list = learn_translation.run(imgs.reduce({Axes.CH, Axes.ZPLANE}, func="max")) +warp = ApplyTransform.Warp() +registered_imgs = warp.run(filtered_imgs, transforms_list=transforms_list, in_place=False, verbose=True) +# EPY: END code + +# EPY: START markdown +### Decode the processed data into spatially resolved gene expression profiles +# EPY: END markdown + +# EPY: START markdown +#```LocalGraphBlobDetector``` instance using [Laplacian of Gaussian (LoG)](https://scikit-image.org/docs/dev/api/skimage.feature.html?highlight=blob_log#skimage.feature.blob_log) blob detection algorithm. Please refer to `scikit-image` documentation for a full parameter list. +# EPY: END markdown + +# EPY: START code +decoded_lgbd = [] +blob_detector = BlobDetector( + min_sigma=(0, 0.5, 0.5), + max_sigma=(0, 3, 3), + num_sigma=10, + threshold=0.03, +) +spots = blob_detector.run(registered_imgs) +lgbd = PerRoundMaxChannel( + codebook=experiment.codebook, + trace_building_strategy=TraceBuildingStrategies.GRAPH, + anchor_round=0, + search_radius=3, + search_radius_max=5 +) +decoded_intensity_table = lgbd.run(spots=spots) +decoded_lgbd.append(decoded_intensity_table) +# EPY: END code + +# EPY: START markdown +#```LocalGraphBlobDetector``` instance using [`peak_local_max`](https://scikit-image.org/docs/dev/api/skimage.feature.html?highlight=peak_local_max#skimage.feature.peak_local_max) local maxima detection algorithm. Please refer to `scikit-image` documentation for a full parameter list. +# EPY: END markdown + +# EPY: START code +import warnings +local_max_detector = LocalMaxPeakFinder( + min_distance=6, + stringency=0, + threshold=0.03, + min_obj_area=0, + max_obj_area=np.inf, +) +spots = local_max_detector.run(registered_imgs) +decoded_intensity_table = lgbd.run(spots=spots) +decoded_lgbd.append(decoded_intensity_table) +# EPY: END code + +# EPY: START markdown +#```LocalGraphBlobDetector``` instance using [`h_maxima`](https://scikit-image.org/docs/dev/api/skimage.morphology.html?highlight=h_maxima#skimage.morphology.h_maxima) local maxima detection algorithm. Please refer to `scikit-image` documentation for a full parameter list. +# EPY: END markdown + +# EPY: START code +import warnings +connectivity=np.array([[[0, 0, 0],[0, 1, 0],[0, 0, 0]],[[0, 1, 0],[1, 1, 1],[0, 1, 0]],[[0, 0, 0],[0, 1, 0],[0, 0, 0]]]) #3D corss + +hmax_detector = HMax(h=0.015, selem=connectivity) +spots = hmax_detector.run(registered_imgs) +decoded_intensity_table = lgbd.run(spots=spots) +decoded_lgbd.append(decoded_intensity_table) +# EPY: END code + +# EPY: START markdown +#We now compare the results from the three decoding approaches used previously with `BlobDetector` algorithm. This spot detection finds spots, and record, for each spot, the maximum pixel intensities across rounds and channels. +# EPY: END markdown + +# EPY: START code +bd = BlobDetector( + min_sigma=0.5, + max_sigma=3, + num_sigma=10, + threshold=0.03, + measurement_type='max', +) +# detect spots using laplacian of gaussians approach +dots_max_projector = Filter.Reduce((Axes.ROUND, Axes.ZPLANE), func="max", module=FunctionSource.np) +dots_max = dots_max_projector.run(fov.get_image('dots')) +# locate spots in a reference image +spots = bd.run(reference_image=dots_max, image_stack=registered_imgs) +decoder = PerRoundMaxChannel(codebook=experiment.codebook) +decoded_bd = decoder.run(spots=spots) +# EPY: END code + +# EPY: START markdown +#To decode the resulting intensity tables, we simply match intensities to codewords in the codebook. This can be done by, for each round, selecting the color channel with the maximum intensity. This forms a potential quaternary code which serves as the key into a lookup in the codebook as to which gene this code corresponds to. Decoded genes are assigned to the target field in the decoded intensity table. +# EPY: END markdown + +# EPY: START markdown +#We now compare the results of the results from three ```LocalGraphBlobDetector``` instances with respect to `BlobDetector` results, plotting the correlation of decoded read counts. +# EPY: END markdown + +# EPY: START code +plt.rc('font', family='serif', size=10) +for i in range(3): + decoded_tmp = decoded_lgbd[i] + genes_lgbd, counts_lgbd = np.unique(decoded_tmp.loc[decoded_tmp[Features.PASSES_THRESHOLDS]][Features.TARGET], return_counts=True) + result_counts_lgbd = pd.Series(counts_lgbd, index=genes_lgbd) + + genes_bd, counts_bd = np.unique(decoded_bd.loc[decoded_bd[Features.PASSES_THRESHOLDS]][Features.TARGET], return_counts=True) + result_counts_bd = pd.Series(counts_bd, index=genes_bd) + + tmp = pd.concat([result_counts_lgbd, result_counts_bd], join='inner', axis=1).values + + r = np.corrcoef(tmp[:, 1], tmp[:, 0])[0, 1] + x = np.linspace(50, 2000) + + f = plt.figure() + ax = plt.subplot(1,2,1) + ax.scatter(tmp[:, 1], tmp[:, 0], 50, zorder=2) + + ax.plot(x, x, '-k', zorder=1) + plt.xlabel('Gene copy number BlobDetector') + plt.ylabel('Gene copy number LGBD') + plt.xscale('log') + plt.yscale('log') + plt.title(f'r = {r}'); +# EPY: END code + +# EPY: START markdown +### Compare to results from paper +#This FOV was selected to make sure that we can visualize the tumor/stroma boundary, below this is described by pseudo-coloring HER2 (tumor) and vimentin (VIM, stroma). This distribution matches the one described in the original paper. +# EPY: END markdown + +# EPY: START code +from skimage.color import rgb2gray + +decoded_lgbd = decoded_lgbd[2] + +GENE1 = 'HER2' +GENE2 = 'VIM' + +rgb = np.zeros(registered_imgs.tile_shape + (3,)) +nuclei_mp = nuclei.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func="max") +nuclei_numpy = nuclei_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE) +rgb[:,:,0] = nuclei_numpy +dots_mp = dots.reduce({Axes.ROUND, Axes.CH, Axes.ZPLANE}, func="max") +dots_mp_numpy = dots_mp._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE) +rgb[:,:,1] = dots_mp_numpy +do = rgb2gray(rgb) +do = do/(do.max()) + +plt.imshow(do,cmap='gray') +plt.axis('off'); + +with warnings.catch_warnings(): + warnings.simplefilter('ignore', FutureWarning) + is_gene1 = decoded_lgbd.where(decoded_lgbd[Features.AXIS][Features.TARGET] == GENE1, drop=True) + is_gene2 = decoded_lgbd.where(decoded_lgbd[Features.AXIS][Features.TARGET] == GENE2, drop=True) + +plt.plot(is_gene1.x, is_gene1.y, 'or', markersize=3) +plt.plot(is_gene2.x, is_gene2.y, 'ob', markersize=3) +plt.title(f'Red: {GENE1}, Blue: {GENE2}'); +# EPY: END code diff --git a/starfish/core/spots/DecodeSpots/per_round_max_channel_decoder.py b/starfish/core/spots/DecodeSpots/per_round_max_channel_decoder.py index f8060200d..fe689cd2b 100644 --- a/starfish/core/spots/DecodeSpots/per_round_max_channel_decoder.py +++ b/starfish/core/spots/DecodeSpots/per_round_max_channel_decoder.py @@ -39,11 +39,16 @@ def __init__( codebook: Codebook, trace_building_strategy: TraceBuildingStrategies = TraceBuildingStrategies.EXACT_MATCH, anchor_round: Optional[int]=1, - search_radius: Optional[int]=3): + search_radius: Optional[int]=3, + search_radius_max: Optional[int] = 5, + k_d: Optional[float]=0.33, + ): self.codebook = codebook self.trace_builder: Callable = TRACE_BUILDERS[trace_building_strategy] self.anchor_round = anchor_round self.search_radius = search_radius + self.search_radius_max = search_radius_max + self.k_d = k_d def run(self, spots: SpotFindingResults, *args) -> DecodedIntensityTable: """Decode spots by selecting the max-valued channel in each sequencing round @@ -61,6 +66,8 @@ def run(self, spots: SpotFindingResults, *args) -> DecodedIntensityTable: """ intensities = self.trace_builder(spot_results=spots, anchor_round=self.anchor_round, - search_radius=self.search_radius) + search_radius=self.search_radius, + search_radius_max=self.search_radius_max, + k_d=self.k_d) transfer_physical_coords_to_intensity_table(intensity_table=intensities, spots=spots) return self.codebook.decode_per_round_max(intensities) diff --git a/starfish/core/spots/DecodeSpots/test/test_graph_trace_builder.py b/starfish/core/spots/DecodeSpots/test/test_graph_trace_builder.py new file mode 100644 index 000000000..98e6df4fa --- /dev/null +++ b/starfish/core/spots/DecodeSpots/test/test_graph_trace_builder.py @@ -0,0 +1,368 @@ +import numpy as np +from scipy.ndimage.filters import gaussian_filter + +from starfish import ImageStack +from starfish.core.spots.DecodeSpots.trace_builders import build_traces_graph_based +from starfish.core.spots.FindSpots import HMax +from starfish.core.types import Axes + + +def traversing_code() -> ImageStack: + """this code walks in a sequential direction""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # code 1 + img[0, 0, 5, 35, 35] = 10 + img[1, 1, 5, 32, 32] = 10 + img[2, 0, 5, 29, 29] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def empty_data() -> ImageStack: + """this code walks in a sequential direction""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + return ImageStack.from_numpy(img) + + +def multiple_possible_neighbors() -> ImageStack: + """last round has more spots""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # round 1 + img[0, 0, 5, 20, 40] = 10 + img[0, 0, 5, 40, 20] = 10 + + # round 2 + img[1, 1, 5, 20, 40] = 10 + img[1, 1, 5, 40, 20] = 10 + + # round 3 + img[2, 0, 5, 20, 40] = 10 + img[2, 0, 5, 35, 35] = 10 + img[2, 0, 5, 40, 20] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def multiple_possible_neighbors_with_jitter() -> ImageStack: + """last round has more spots and spots have some jitter <= 10px""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # round 1 + img[0, 0, 5, 20, 40] = 10 + img[0, 0, 5, 40, 10] = 10 + + # round 2 + img[1, 1, 5, 20, 45] = 10 + img[1, 1, 5, 40, 30] = 10 + + # round 3 + img[2, 0, 5, 20, 40] = 10 + img[2, 0, 5, 40, 20] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def multiple_possible_neighbors_with_jitter_with_noise() -> ImageStack: + """last round has more spots and spots have some jitter <= 10px""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # round 1 + img[0, 0, 5, 20, 40] = 10 + img[0, 1, 5, 40, 20] = 10 + + # round 2 + img[1, 1, 5, 20, 45] = 10 + img[1, 1, 5, 30, 30] = 10 + + # round 3 + img[2, 0, 5, 20, 40] = 10 + img[2, 0, 5, 30, 20] = 10 + img[2, 1, 5, 40, 30] = 1 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def channels_crosstalk() -> ImageStack: + """this code has spots with intensity crosstalk between channels of the same round""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # round 1 + img[0, 0, 5, 20, 40] = 10 + img[0, 1, 5, 20, 40] = 5 + + # round 2 + img[1, 0, 4, 20, 40] = 5 + img[1, 1, 5, 20, 40] = 10 + + # round 3 + img[2, 0, 5, 20, 40] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def jitter_code() -> ImageStack: + """this code has some minor jitter <= 3px at the most distant point""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # code 1 + img[0, 0, 5, 35, 35] = 10 + img[1, 1, 5, 34, 35] = 10 + img[2, 0, 6, 35, 33] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def two_perfect_codes() -> ImageStack: + """this code has no jitter""" + img = np.zeros((3, 2, 20, 50, 50), dtype=np.float32) + + # code 1 + img[0, 0, 5, 20, 35] = 10 + img[1, 1, 5, 20, 35] = 10 + img[2, 0, 5, 20, 35] = 10 + + # code 1 + img[0, 0, 5, 40, 45] = 10 + img[1, 1, 5, 40, 45] = 10 + img[2, 0, 5, 40, 45] = 10 + + # blur points + gaussian_filter(img, (0, 0, 0.5, 1.5, 1.5), output=img) + + return ImageStack.from_numpy(img) + + +def simple_hmax_detector(): + return HMax(h=0.5) + + +h_max_detector = simple_hmax_detector() + + +def test_local_graph_blob_detector_empty_data(): + stack = empty_data() + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=1, + search_radius_max=1, + anchor_round=0) + + assert intensity_table.shape == (0, 3, 2) + f, c, r = np.nonzero(intensity_table.values) + assert np.all(f == 0) + assert np.all(c == 0) + assert np.all(r == 0) + + +def test_local_graph_blob_detector_two_codes(): + stack = two_perfect_codes() + # Find spots with 'h-maxima' + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=1, + search_radius_max=1, + anchor_round=0) + + assert intensity_table.shape == (2, 2, 3) + assert np.all(intensity_table[0][Axes.X.value] == 35) + assert np.all(intensity_table[0][Axes.Y.value] == 20) + assert np.all(intensity_table[0][Axes.ZPLANE.value] == 5) + + +def test_local_graph_blob_detector_jitter_code(): + stack = jitter_code() + spots = h_max_detector.run(image_stack=stack) + + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=3, + search_radius_max=3, + anchor_round=0) + + assert intensity_table.shape == (1, 2, 3) + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == np.array([0, 0, 0])) + assert np.all(c == np.array([0, 0, 1])) + assert np.all(r == np.array([0, 2, 1])) + + # test again with smaller search radius + spots = h_max_detector.run(image_stack=stack) + + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=1, + search_radius_max=5, + anchor_round=0) + + assert intensity_table.shape == (0, 2, 3) + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == 0) + assert np.all(c == 0) + assert np.all(r == 0) + + # test again with smaller search radius max + spots = h_max_detector.run(image_stack=stack) + + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=3, + search_radius_max=1, + anchor_round=0) + + assert intensity_table.shape == (0, 2, 3) + f, c, r = np.nonzero(intensity_table.values) + assert np.all(f == 0) + assert np.all(c == 0) + assert np.all(r == 0) + + +def test_local_graph_blob_detector_traversing_code(): + stack = traversing_code() + spots = h_max_detector.run(image_stack=stack) + + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=5, + search_radius_max=10, + anchor_round=0) + + assert intensity_table.shape == (1, 2, 3) + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == np.array([0, 0, 0])) + assert np.all(c == np.array([0, 0, 1])) + assert np.all(r == np.array([0, 2, 1])) + + spots = h_max_detector.run(image_stack=stack) + + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=5, + search_radius_max=5, + anchor_round=0) + + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == 0) + assert np.all(c == 0) + assert np.all(r == 0) + + +def test_local_graph_blob_detector_multiple_neighbors(): + stack = multiple_possible_neighbors() + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=4, + search_radius_max=4, + anchor_round=0) + + assert intensity_table.shape == (2, 2, 3) + assert np.all(intensity_table[Axes.ZPLANE.value] == (5, 5)) + assert np.all(intensity_table[Axes.Y.value] == (20, 40)) + assert np.all(intensity_table[Axes.X.value] == (40, 20)) + + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=15, + search_radius_max=20, + anchor_round=0) + + assert intensity_table.shape == (2, 2, 3) + assert np.all(intensity_table[Axes.ZPLANE.value] == (5, 5)) + assert np.all(intensity_table[Axes.Y.value] == (20, 40)) + assert np.all(intensity_table[Axes.X.value] == (40, 20)) + + +def test_local_graph_blob_detector_multiple_neighbors_with_jitter(): + stack = multiple_possible_neighbors_with_jitter() + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=10, + search_radius_max=20, + anchor_round=0) + + assert intensity_table.shape == (2, 2, 3) + assert np.all(intensity_table[Axes.ZPLANE.value] == (5, 5)) + assert np.all(intensity_table[Axes.Y.value] == (20, 40)) + assert np.all(intensity_table[Axes.X.value] == (40, 10)) + + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=15, + search_radius_max=15, + anchor_round=0) + + assert intensity_table.shape == (1, 2, 3) + assert np.all(intensity_table[Axes.ZPLANE.value] == (5)) + assert np.all(intensity_table[Axes.Y.value] == (20)) + assert np.all(intensity_table[Axes.X.value] == (40)) + + +def test_local_graph_blob_detector_multiple_neighbors_with_jitter_with_noise(): + stack = multiple_possible_neighbors_with_jitter_with_noise() + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=10, + search_radius_max=20, + anchor_round=0) + + assert intensity_table.shape == (2, 2, 3) + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == np.array([0, 0, 0, 1, 1, 1])) + assert np.all(c == np.array([0, 0, 1, 0, 1, 1])) + assert np.all(r == np.array([0, 2, 1, 2, 0, 1])) + + +def test_local_graph_blob_detector_channels_crosstalk(): + stack = channels_crosstalk() + spots = h_max_detector.run(image_stack=stack) + intensity_table = build_traces_graph_based( + spot_results=spots, + k_d=0.33, + search_radius=3, + search_radius_max=5, + anchor_round=0) + + assert intensity_table.shape == (1, 2, 3) + f, c, r = np.where(~intensity_table.isnull()) + assert np.all(f == np.array([0, 0, 0, 0, 0])) + assert np.all(c == np.array([0, 0, 0, 1, 1])) + assert np.all(r == np.array([0, 1, 2, 0, 1])) diff --git a/starfish/core/spots/DecodeSpots/trace_builders.py b/starfish/core/spots/DecodeSpots/trace_builders.py index b54ff373d..52d7cce4e 100644 --- a/starfish/core/spots/DecodeSpots/trace_builders.py +++ b/starfish/core/spots/DecodeSpots/trace_builders.py @@ -1,11 +1,19 @@ from typing import Callable, Mapping +import numpy as np import pandas as pd from starfish.core.intensity_table.intensity_table import IntensityTable from starfish.core.types import Axes, Features, SpotAttributes, SpotFindingResults, \ TraceBuildingStrategies -from .util import _build_intensity_table, _match_spots, _merge_spots_by_round +from .util import ( + _build_intensity_table_graph_results, + _build_intensity_table_nearest_neighbor_results, + _build_spot_traces_per_round, + _compute_spot_trace_qualities, + _match_spots, + _merge_spots_by_round +) def build_spot_traces_exact_match(spot_results: SpotFindingResults, **kwargs) -> IntensityTable: @@ -67,8 +75,12 @@ def build_traces_sequential(spot_results: SpotFindingResults, **kwargs) -> Inten return intensity_table -def build_traces_nearest_neighbors(spot_results: SpotFindingResults, anchor_round: int=1, - search_radius: int=3): +def build_traces_nearest_neighbors( + spot_results: SpotFindingResults, + anchor_round: int=1, + search_radius: int=3, + **kwargs +): """ Combine spots found across round and channels of an ImageStack using a nearest neighbors strategy @@ -90,7 +102,7 @@ def build_traces_nearest_neighbors(spot_results: SpotFindingResults, anchor_roun per_round_spot_results, anchor_round=anchor_round ) - intensity_table = _build_intensity_table( + intensity_table = _build_intensity_table_nearest_neighbor_results( per_round_spot_results, distances, indices, rounds=spot_results.round_labels, channels=spot_results.ch_labels, @@ -100,8 +112,84 @@ def build_traces_nearest_neighbors(spot_results: SpotFindingResults, anchor_roun return intensity_table +def build_traces_graph_based( + spot_results: SpotFindingResults, + k_d: float, + search_radius: int, + search_radius_max: int, + anchor_round: int, + **kwargs +): + """ + Overlapping spots are merged across channels within each round in order to handle fluorescent + bleed-trough. Next, a quality score is assigned for each detected spot (maximum intensity + divided by intensity vector l2-norm). Detected spots belonging to different sequencing rounds + and closer than d_th are connected in a graph, forming connected components of spot detections. + Next, for each connected component, edges between not connected spots belonging to consecutive + rounds are forced if they are closer than dth_max. Finally, all the edges that connect spots + non belonging to consecutive rounds are removed and each connected component is solved by + maximum flow minimum cost algorithm. Costs are inversely proportional to spot quality and + distances. The final intensity table is then initialized with the intensity table of the + round chosen as anchor (default: first round). + + + Parameters + ---------- + spot_results : SpotFindingResults + Spots found across rounds/channels of an ImageStack + anchor_round : int + The imaging round against which other rounds will be checked for spots in the same + approximate pixel location. + search_radius : int + Euclidean distance in pixels over which to search for spots in subsequent rounds. + search_radius_max : int + The maximum (euclidian) distance in pixels allowed between nodes belonging to the + same sequence + k_d : float + distance weight + + Notes + ----- + [1] Partel, G. et al. Identification of spatial compartments in tissue from in situ sequencing + data. BioRxiv, https://doi.org/10.1101/765842, (2019). + """ + if spot_results.count_total_spots() == 0: + spot_attributes = list(spot_results.values())[0] + return IntensityTable.zeros( + spot_attributes=spot_attributes, + round_labels=spot_results.round_labels, + ch_labels=spot_results.ch_labels, + ) + else: + round_dataframes = _merge_spots_by_round(spot_results) + + spot_traces = _build_spot_traces_per_round( + round_dataframes, + channels=spot_results.ch_labels, + rounds=spot_results.round_labels) + + spot_traces = _compute_spot_trace_qualities(spot_traces) + + intensity_table = _build_intensity_table_graph_results( + intensity_tables=spot_traces, + rounds=spot_results.round_labels, + search_radius=search_radius, + search_radius_max=search_radius_max, + k_d=k_d, + anchor_round=anchor_round) + + # Drop intensities with empty rounds + drop = [np.any(np.all(np.isnan(intensity_table.values[x, :, :]), axis=0)) + for x in range(intensity_table.shape[0])] + intensity_table = IntensityTable( + intensity_table[np.arange(intensity_table.shape[0])[np.invert(drop)]]) + + return intensity_table + + TRACE_BUILDERS: Mapping[TraceBuildingStrategies, Callable] = { TraceBuildingStrategies.EXACT_MATCH: build_spot_traces_exact_match, TraceBuildingStrategies.NEAREST_NEIGHBOR: build_traces_nearest_neighbors, TraceBuildingStrategies.SEQUENTIAL: build_traces_sequential, + TraceBuildingStrategies.GRAPH: build_traces_graph_based } diff --git a/starfish/core/spots/DecodeSpots/util.py b/starfish/core/spots/DecodeSpots/util.py index be64c672c..7c414ab68 100644 --- a/starfish/core/spots/DecodeSpots/util.py +++ b/starfish/core/spots/DecodeSpots/util.py @@ -1,9 +1,13 @@ from collections import defaultdict +from itertools import combinations from typing import Any, Dict, Hashable, List, Mapping, Sequence, Tuple +import networkx as nx import numpy as np import pandas as pd +from scipy.spatial import cKDTree as KDTree from sklearn.neighbors import NearestNeighbors +from tqdm import tqdm from starfish.core.intensity_table.intensity_table import IntensityTable from starfish.core.types import Axes, Features, SpotFindingResults @@ -56,7 +60,7 @@ def _match_spots( return dist, ind -def _build_intensity_table( +def _build_intensity_table_nearest_neighbor_results( round_dataframes: Dict[int, pd.DataFrame], dist: pd.DataFrame, ind: pd.DataFrame, @@ -122,6 +126,449 @@ def _build_intensity_table( return intensity_table +def _build_spot_traces_per_round( + round_dataframes: Dict[int, pd.DataFrame], + channels: Sequence[int], + rounds: Sequence[int] +) -> Dict[int, IntensityTable]: + """ For each round, find connected components of spots across channels and merge them + in a single spot trace. + + Parameters + ---------- + round_dataframes : Dict[int, pd.DataFrame] + Output from _merge_spots_by_round, contains mapping of image volumes from each round to + all the spots detected in them. + channels, rounds : Sequence[int] + Channels and rounds present in the ImageStack from which spots were detected. + Returns + ------- + Dict[int, IntensityTable] + Dictionary mapping round to the relative IntensityTable. + """ + intensity_tables = {} + + # get spots matching across channels + for r, df in round_dataframes.items(): + # Find connected components across channels + G = nx.Graph() + kdT = KDTree(df.loc[:, [Axes.ZPLANE.value, Axes.Y.value, Axes.X.value]].values) + pairs = kdT.query_pairs(1, p=1) + G.add_nodes_from(df.index.values) + G.add_edges_from(pairs) + conn_comps = [list(i) for i in nx.connected_components(G)] + # for each connected component keep detection with highest intensity + refined_conn_comps = [] + for i in range(len(conn_comps)): + df_tmp = df.loc[conn_comps[i], :] + kdT_tmp = KDTree(df_tmp.loc[:, [Axes.ZPLANE.value, Axes.Y.value, Axes.X.value]].values) + # Check if all spots whitin a conn component are at most 1 pixels away + # from each other (Manhattan distance) + spot_pairs = len(list(combinations(np.arange(len(df_tmp)), 2))) + spots_connected = len(kdT_tmp.query_pairs(2, p=1)) # 2 could be a parameter + if spot_pairs == spots_connected: + # Merge spots + refined_conn_comps.append(conn_comps[i]) + else: + # split non overlapping signals + for j, row in df_tmp.drop_duplicates([Axes.ZPLANE.value, Axes.Y.value, + Axes.X.value]).iterrows(): + refined_conn_comps.append(df_tmp[(df_tmp.z == row.z) & (df_tmp.y == row.y) + & (df_tmp.x == row.x)].index.values.tolist()) + + data = np.full((len(refined_conn_comps), len(channels), len(rounds)), + fill_value=np.nan) + spot_radius = [] + z = [] + y = [] + x = [] + for f_idx, s in enumerate(refined_conn_comps): + df_tmp = df.loc[s] + anchor_s_idx = df_tmp.intensity.idxmax() + z.append(df_tmp.loc[anchor_s_idx, Axes.ZPLANE.value]) + y.append(df_tmp.loc[anchor_s_idx, Axes.Y.value]) + x.append(df_tmp.loc[anchor_s_idx, Axes.X.value]) + spot_radius.append(df_tmp.loc[anchor_s_idx, Features.SPOT_RADIUS]) + for i, row in df_tmp.iterrows(): + data[f_idx, int(row.c), r] = row.intensity + # # create IntensityTable + dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value) + coords: Mapping[Hashable, Tuple[str, Any]] = { + Features.SPOT_RADIUS: (Features.AXIS, spot_radius), + Axes.ZPLANE.value: (Features.AXIS, z), + Axes.Y.value: (Features.AXIS, y), + Axes.X.value: (Features.AXIS, x), + Axes.ROUND.value: (Axes.ROUND.value, rounds), + Axes.CH.value: (Axes.CH.value, channels) + } + intensity_table = IntensityTable(data=data, dims=dims, coords=coords) + intensity_tables[r] = intensity_table + + return intensity_tables + + +def _baseCalling(data: list, rounds: Sequence[int], search_radius_max: int) -> np.ndarray: + """Extract intensity table feature indeces and channels from each connected component graph + + Parameters + ---------- + data : list + Output from _runGraphDecoder, contains decoded spots + rounds : Sequence[int] + Rounds present in the ImageStack from which spots were detected + search_radius_max : int + The maximum (euclidian) distance in pixels allowed between nodes belonging + to the same sequence + + Returns + ------- + np.ndarray + feature indeces arrays of _merge_spots_by_round output intensity tables ordered by round + """ + idx = [] + if data: + for graph in tqdm(data): + G = graph['G'] + Dvar = graph['Dvar'] + for c in nx.connected_components(G): + c = np.array(list(c)) + c = c[c <= Dvar.X_idx.max()] + Dvar_c = Dvar[(Dvar.X_idx.isin(c))] + if len(Dvar_c) == len(rounds): + k1 = KDTree(Dvar_c[[Axes.X.value, Axes.Y.value, Axes.ZPLANE.value]].values) + max_d = np.amax(list(k1.sparse_distance_matrix(k1, np.inf).values())) + if max_d <= search_radius_max: + idx.append(Dvar[ + (Dvar.X_idx.isin(c))].sort_values(['r']).feature_id.values) + return np.array(idx).astype(np.uint) + + +def _compute_spot_trace_qualities(intensity_tables: Dict[int, IntensityTable] + ) -> Dict[int, IntensityTable]: + """Interate over the intesity tables of each round and assign to each feature a quality score + Parameters + ---------- + + Returns + ------- + Dict[int,IntensityTable]: + Dictionary mapping round to the relative IntensityTable with quality coordinate Q + representing the quality score of each feature. + """ + for i in intensity_tables: + intensity_tables[i]['Q'] = (Features.AXIS, + np.divide(np.amax(intensity_tables[i].fillna(0).values, + axis=1), + np.linalg.norm( + intensity_tables[i].fillna(0).values, + 2, axis=1), + where=np.linalg.norm( + intensity_tables[i].fillna(0).values, + 2, axis=1) != 0)[:, i]) + return intensity_tables + + +def _runGraphBuilder(data: pd.DataFrame, + d_th: float, + k_d: float, + dth_max: float) -> list: + """Find connected components of detected spots across rounds and call the graph + decoder for each connected component instance. + + Parameters + ---------- + data : pd.DataFrame + Dataframe of detected spots with probability values, with columns + [x, y, z, r, c, idx, p0, p1, feature_id] + d_th : flaot + maximum distance inside connected component between two connected spots + k_d : float + distance weight + dth_max : float + maximum distance inside connected component between every pair of spots + + Returns + ------- + list[Dict[str,Any]] + List of dictionaries output of _runMaxFlowMinCost + """ + print("Generate Graph Model...\n") + num_hyb = np.arange(0, int(np.amax(data.r)) + 1) + data.sort_values('r', inplace=True) + data = data.reset_index(drop=True) + # Graphical Model Data Structures + # Generate connected components + G = nx.Graph() + G.add_nodes_from(data.index.values) + for h1 in tqdm(num_hyb): + KDTree_h1 = KDTree(data[data.r == h1][[Axes.X.value, Axes.Y.value, Axes.ZPLANE.value]]) + for h2 in num_hyb[h1:]: + if h1 != h2: + KDTree_h2 = KDTree(data[data.r == h2][[Axes.X.value, Axes.Y.value, + Axes.ZPLANE.value]]) + query = KDTree_h1.query_ball_tree(KDTree_h2, d_th, p=2) + E = [] + offset1 = data.index[data.r == h1].min() + offset2 = data.index[data.r == h2].min() + for i, e1 in enumerate(query): + if e1: + for e2 in e1: + E.append((i + offset1, e2 + offset2)) + G.add_edges_from(E) + + conn_comps = [list(c) for c in nx.connected_components(G)] + for c in tqdm(range(len(conn_comps))): + data.loc[conn_comps[c], 'conn_comp'] = c + + # Drop conn components with less than n_hybs elements + gr = data.groupby('conn_comp') + for i, group in gr: + if len(group) < len(num_hyb): + data = data.drop(group.index) + labels = np.unique(data.conn_comp) + + if labels.size > 0: + print("Run Graph Model...\n") + res = [] + for l in tqdm(np.nditer(labels), total=len(labels)): + res.append(_runMaxFlowMinCost(data, int(l), d_th, k_d, num_hyb, dth_max)) + # return maxFlowMinCost + return [x for x in res if x['G'] is not None] + else: + return [] + + +def _prob2Eng(p: float) -> float: + """Convert probability values into energy by inverse Gibbs distribution + + Parameters + ---------- + p : float + probability value + + Returns + ------- + float + energy value + """ + return -1.0 * np.log(np.clip(p, 0.00001, 0.99999)) + + +def _runMaxFlowMinCost( + data: pd.DataFrame, + l: int, + d_th: float, + k_d: float, + rounds: np.array, + dth_max: float) -> Dict: + """Build the graph model for the given connected component and solve the graph + with max flow min cost alghorithm + + Parameters + ---------- + data : pd.DataFrame + Dataframe of detected spots with probability values, with columns + [x, y, z, r, c, idx, p0, p1] + l : int + connected component index + d_th : float + maximum distance inside connected component between two connected spots + k_d : float + distance weight + rounds : np.array[int] + Channels and rounds present in the ImageStack from which spots were detected. + dth_max : float + maximum distance inside connected component between every pair of spots + + Returns + ------- + Dict[str, Any] + Dictionary mapping the decoded graph, Dataframe of detected spots with + probability values, Dataframe of connected spots + """ + + if len(data[data.conn_comp == l]): + if len(np.unique(data[data.conn_comp == l].r)) == len(rounds): + data_tmp = data[data.conn_comp == l].sort_values(['r']).copy() + data_tmp.reset_index(inplace=True, drop=True) + Dvar_tmp = data_tmp.loc[:, [Axes.X.value, Axes.Y.value, Axes.ZPLANE.value, + Axes.ROUND.value, Axes.CH.value, 'feature_id']] + Dvar_tmp['E_0'] = data_tmp.p0.apply(_prob2Eng) + Dvar_tmp['E_1'] = data_tmp.p1.apply(_prob2Eng) + Dvar_tmp['X_idx'] = data_tmp.index.values + + X_idx_tmp = len(Dvar_tmp) + Tvar_tmp = pd.DataFrame( + data={'x_idx': [], + 'anchestor_x_idx': [], + 'descendant_x_idx': [], + 'E_0': [], + 'E_1': [], + 'mu_d': []}) + for h1 in rounds[:-1]: + h2 = h1 + 1 + + df1 = data_tmp[data_tmp.r == h1] + df2 = data_tmp[data_tmp.r == h2] + df1_coords = df1[[Axes.X.value, Axes.Y.value, Axes.ZPLANE.value]].values + df2_coords = df2[[Axes.X.value, Axes.Y.value, Axes.ZPLANE.value]].values + + KDTree_h1 = KDTree(df1_coords) + KDTree_h2 = KDTree(df2_coords) + query = KDTree_h1.query_ball_tree(KDTree_h2, dth_max, p=2) + for i in range(len(query)): + if len(query[i]): + X_idx = [(X_idx_tmp + x) for x in range(len(query[i]))] + d = [np.linalg.norm(df1_coords[i] - df2_coords[x]) for x in query[i]] + mu_d = [1 / (1 + k_d * x) for x in d] + + Tvar_tmp = Tvar_tmp.append( + pd.DataFrame(data={ + 'x_idx': X_idx, + 'anchestor_x_idx': np.ones(len(query[i])) * df1.index[i], + 'descendant_x_idx': df2.index[query[i]].values, + 'E_0': [_prob2Eng(1 - x) for x in mu_d], + 'E_1': [_prob2Eng(x) for x in mu_d], + 'mu_d': mu_d})) + X_idx_tmp = X_idx[-1] + 1 + + Dvar_tmp.X_idx = Dvar_tmp.X_idx + 1 + Tvar_tmp.anchestor_x_idx = Tvar_tmp.anchestor_x_idx + 1 + Tvar_tmp.descendant_x_idx = Tvar_tmp.descendant_x_idx + 1 + + Dvar_tmp['X'] = np.arange(1, len(Dvar_tmp) + 1) + + sink = Dvar_tmp.X.max() + 1 + + # Inizialize graph + G = nx.DiGraph() + + E = [] # Edges + n = sink + 1 + + for h in rounds: + for idx, row in Dvar_tmp[(Dvar_tmp.r == h)].iterrows(): + if h == 0: + E.append((0, row.X, {'capacity': 1, 'weight': 0})) + # Add detection edges + E.append((row.X, n, { + 'capacity': 1, + 'weight': np.round(row.E_1 * 1000000).astype(int)})) + n = n + 1 + + G.add_edges_from(E) + E = [] + for idx, row in Tvar_tmp[( + Tvar_tmp.anchestor_x_idx.isin( + Dvar_tmp[Dvar_tmp.r == h].X_idx))].iterrows(): + # Add transition edges + E.append((list(G.successors(row.anchestor_x_idx))[0], row.descendant_x_idx, + {'capacity': 1, + 'weight': np.round(row.E_1 * 1000000).astype(int)})) + G.add_edges_from(E) + + # For each D of last cycle connect to sink + E = [] + for idx, row in Dvar_tmp[(Dvar_tmp.r == rounds.max())].iterrows(): + E.append((list(G.successors(row.X_idx))[0], sink, {'capacity': 1, 'weight': 0})) + G.add_edges_from(E) + + # Prune graph removing leaf nodes + remove_nodes = [] + for n in G.nodes: + n_set = nx.algorithms.descendants(G, n) + if sink not in n_set: + remove_nodes.append(n) + if n == 0: # source and sink are not connected + return {'G': None, 'Dvar': None, 'Tvar': None} + + remove_nodes.remove(sink) + G.remove_nodes_from(remove_nodes) + + MaxFlowMinCost = nx.max_flow_min_cost(G, 0, sink) + # Decode sequence + E = [] + for n1 in MaxFlowMinCost: + for n2 in MaxFlowMinCost[n1]: + if MaxFlowMinCost[n1][n2] == 1: + E.append((int(n1), n2, {})) + G = nx.Graph() + G.add_edges_from(E) + G.remove_node(0) + G.remove_node(sink) + + return {'G': G, 'Dvar': Dvar_tmp, 'Tvar': Tvar_tmp} + else: + return {'G': None, 'Dvar': None, 'Tvar': None} + else: + return {'G': None, 'Dvar': None, 'Tvar': None} + + +def _build_intensity_table_graph_results(intensity_tables: Dict[int, IntensityTable], + rounds: Sequence[int], + search_radius: int, + search_radius_max: int, + k_d: float, + anchor_round: int + ) -> IntensityTable: + """Construct an intensity table from the results of a graph based search of detected spots + + Parameters + ---------- + intensity_tables : Dict[int, IntensityTable] + Output from _merge_spots_by_round, contains mapping of intensity tables + from each round to all the spots detected in them. + channels, rounds : Sequence[int] + Channels and rounds present in the ImageStack from which spots were detected. + search_radius : int + Euclidean distance in pixels over which to search for spots in subsequent rounds. + search_radius_max : int + The maximum (euclidian) distance in pixels allowed between nodes belonging + to the same sequence + k_d : float + distance weight + anchor_round : int + The imaging round to seed the search from. + + Returns + ------- + IntensityTable + Intensity table from the results of a graph based search of detected spots + """ + + anchor_intensity_table = intensity_tables[anchor_round] + data = pd.DataFrame() + for i in intensity_tables: + data = data.append( + pd.DataFrame({Axes.X.value: intensity_tables[i][Axes.X.value].values, + Axes.Y.value: intensity_tables[i][Axes.Y.value].values, + Axes.ZPLANE.value: intensity_tables[i][Axes.ZPLANE.value].values, + Axes.CH.value: np.argmax(intensity_tables[i].fillna( + 0).values, axis=1)[:, i], + Axes.ROUND.value: i, + 'Imax_gf': np.amax(intensity_tables[i].fillna(0).values, + axis=1)[:, i], + 'p1': intensity_tables[i]['Q'].values, + 'p0': 1 - intensity_tables[i]['Q'].values, + 'feature_id': intensity_tables[i].features.values}), + ignore_index=True) + + res = _runGraphBuilder(data, search_radius, k_d, search_radius_max) + idx = _baseCalling(res, rounds, search_radius_max) + + # Initialize IntensityTable with anchor round IntensityTable + intensity_table = anchor_intensity_table.drop('Q') + + # fill IntensityTable + if len(idx): + for r in rounds: + # need numpy indexing to set values in vectorized manner + intensity_table.values[ + idx[:, anchor_round], :, r] = intensity_tables[r].values[idx[:, r], :, r] + + return IntensityTable(intensity_table) + + def _merge_spots_by_round( spot_results: SpotFindingResults ) -> Dict[int, pd.DataFrame]: diff --git a/starfish/core/spots/FindSpots/__init__.py b/starfish/core/spots/FindSpots/__init__.py index 4b62b0db3..f08d8d914 100644 --- a/starfish/core/spots/FindSpots/__init__.py +++ b/starfish/core/spots/FindSpots/__init__.py @@ -1,5 +1,6 @@ from ._base import FindSpotsAlgorithm from .blob import BlobDetector +from .h_max import HMax from .local_max_peak_finder import LocalMaxPeakFinder from .trackpy_local_max_peak_finder import TrackpyLocalMaxPeakFinder diff --git a/starfish/core/spots/FindSpots/h_max.py b/starfish/core/spots/FindSpots/h_max.py new file mode 100644 index 000000000..cffaaa899 --- /dev/null +++ b/starfish/core/spots/FindSpots/h_max.py @@ -0,0 +1,145 @@ +from functools import partial +from typing import Optional, Union + +import numpy as np +import pandas as pd +import xarray as xr +from scipy.spatial import distance +from skimage import img_as_float +from skimage.measure import label +from skimage.morphology import h_maxima + +from starfish.core.imagestack.imagestack import ImageStack +from starfish.core.spots.FindSpots import spot_finding_utils +from starfish.core.types import Axes, Features, SpotAttributes, SpotFindingResults +from ._base import FindSpotsAlgorithm + + +class HMax(FindSpotsAlgorithm): + """ + Determine all maxima of the image with height >= h. + + Parameters + ---------- + h : unsigned integer + The minimal height of all extracted maxima. + selem : ndarray, optional + The neighborhood expressed as an n-D array of 1's and 0's. + Default is the ball of radius 1 according to the maximum norm + (i.e. a 3x3 square for 2D images, a 3x3x3 cube for 3D images, etc.) + + Notes + ----- + https://scikit-image.org/docs/dev/api/skimage.morphology + """ + + def __init__(self, h, selem: np.ndarray=None, is_volume=True, measurement_type='max'): + self.h = h + self.selem = selem + self.is_volume = is_volume + self.measurement_function = self._get_measurement_function(measurement_type) + + def image_to_spots(self, data_image: Union[np.ndarray, xr.DataArray]) -> SpotAttributes: + """ + Find spots using a h_maxima algorithm + + Parameters + ---------- + data_image : Union[np.ndarray, xr.DataArray] + image containing spots to be detected + + Returns + ------- + SpotAttributes : + DataFrame of metadata containing the coordinates, intensity and radius of each spot + + """ + + results = h_maxima(image=img_as_float(data_image), h=self.h, selem=self.selem) + + data_image = np.asarray(data_image) + + label_h_max = label(results, neighbors=4) + # no maxima present in image + if (label_h_max == np.ones(label_h_max.shape)).all(): + max_mask = np.zeros(data_image.shape) + else: + labels = pd.DataFrame( + data={'labels': np.sort(label_h_max[np.where(label_h_max != 0)])}) + # find duplicates labels (=connected components) + dup = labels.index[labels.duplicated()].tolist() + + # splitting connected regional maxima to get only one local maxima + max_mask = np.zeros(data_image.shape) + max_mask[label_h_max != 0] = 1 + + # Compute medoid for connected regional maxima + for i in range(len(dup)): + # find coord of points having the same label + z, r, c = np.where(label_h_max == labels.loc[dup[i], 'labels']) + meanpoint_x = np.mean(c) + meanpoint_y = np.mean(r) + meanpoint_z = np.mean(z) + dist = [distance.euclidean([meanpoint_z, meanpoint_y, meanpoint_x], + [z[j], r[j], c[j]]) for j in range(len(r))] + ind = dist.index(min(dist)) + # delete values at ind position. + z, r, c = np.delete(z, ind), np.delete(r, ind), np.delete(c, ind) + max_mask[z, r, c] = 0 # set to 0 points != medoid coordinates + results = max_mask.nonzero() + results = np.vstack(results).T + + spot_data = pd.DataFrame( + data={Axes.X.value: results[:, 2], + Axes.Y.value: results[:, 1], + Axes.ZPLANE.value: results[:, 0], + Features.SPOT_RADIUS: 1, + Features.SPOT_ID: np.arange(results.shape[0]), + Features.INTENSITY: data_image[results[:, 0], + results[:, 1], + results[:, 2]] + }) + return SpotAttributes(spot_data) + + def run( + self, + image_stack: ImageStack, + reference_image: Optional[ImageStack] = None, + n_processes: Optional[int] = None, + *args, + ) -> SpotFindingResults: + """ + Find spots in the given ImageStack using a gaussian blob finding algorithm. + If a reference image is provided the spots will be detected there then measured + across all rounds and channels in the corresponding ImageStack. If a reference_image + is not provided spots will be detected _independently_ in each channel. This assumes + a non-multiplex imaging experiment, as only one (ch, round) will be measured for each spot. + + Parameters + ---------- + image_stack : ImageStack + ImageStack where we find the spots in. + reference_image : xr.DataArray + (Optional) a reference image. If provided, spots will be found in this image, and then + the locations that correspond to these spots will be measured across each channel. + n_processes : Optional[int] = None, + Number of processes to devote to spot finding. + """ + spot_finding_method = partial(self.image_to_spots, *args) + if reference_image: + data_image = reference_image._squeezed_numpy(*{Axes.ROUND, Axes.CH}) + reference_spots = spot_finding_method(data_image) + results = spot_finding_utils.measure_intensities_at_spot_locations_across_imagestack( + data_image=image_stack, + reference_spots=reference_spots, + measurement_function=self.measurement_function) + else: + spot_attributes_list = image_stack.transform( + func=spot_finding_method, + group_by={Axes.ROUND, Axes.CH}, + n_processes=n_processes + ) + results = SpotFindingResults(imagestack_coords=image_stack.xarray.coords, + log=image_stack.log, + spot_attributes_list=spot_attributes_list) + return results diff --git a/starfish/core/spots/FindSpots/test/test_spot_detection.py b/starfish/core/spots/FindSpots/test/test_spot_detection.py index 01ec4e54e..f06b484b9 100644 --- a/starfish/core/spots/FindSpots/test/test_spot_detection.py +++ b/starfish/core/spots/FindSpots/test/test_spot_detection.py @@ -11,6 +11,7 @@ from starfish.types import Axes, FunctionSource from .._base import FindSpotsAlgorithm from ..blob import BlobDetector +from ..h_max import HMax from ..local_max_peak_finder import LocalMaxPeakFinder from ..trackpy_local_max_peak_finder import TrackpyLocalMaxPeakFinder @@ -54,10 +55,17 @@ def simple_local_max_spot_detector() -> LocalMaxPeakFinder: threshold=0 ) + +def simple_h_max_detector() -> HMax: + return HMax( + h=0.5 + ) + # initialize spot detectors gaussian_spot_detector = simple_gaussian_spot_detector() trackpy_local_max_spot_detector = simple_trackpy_local_max_spot_detector() local_max_spot_detector = simple_local_max_spot_detector() +h_max_detector = simple_h_max_detector() # test parameterization test_parameters = ( @@ -66,12 +74,15 @@ def simple_local_max_spot_detector() -> LocalMaxPeakFinder: (ONE_HOT_IMAGESTACK, gaussian_spot_detector), (ONE_HOT_IMAGESTACK, trackpy_local_max_spot_detector), (ONE_HOT_IMAGESTACK, local_max_spot_detector), + (ONE_HOT_IMAGESTACK, h_max_detector), (SPARSE_IMAGESTACK, gaussian_spot_detector), (SPARSE_IMAGESTACK, trackpy_local_max_spot_detector), (SPARSE_IMAGESTACK, local_max_spot_detector), + (SPARSE_IMAGESTACK, h_max_detector), (BLANK_IMAGESTACK, gaussian_spot_detector), (BLANK_IMAGESTACK, trackpy_local_max_spot_detector), (BLANK_IMAGESTACK, local_max_spot_detector), + (BLANK_IMAGESTACK, h_max_detector), ] ) diff --git a/starfish/core/types/_constants.py b/starfish/core/types/_constants.py index 9ebcf15d4..e5f6dd7b1 100644 --- a/starfish/core/types/_constants.py +++ b/starfish/core/types/_constants.py @@ -111,5 +111,6 @@ class TraceBuildingStrategies(AugmentedEnum): currently support spot trace building strategies """ EXACT_MATCH = 'exact_match' + GRAPH = 'graph' NEAREST_NEIGHBOR = 'nearest_neighbor' SEQUENTIAL = 'sequential'