Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-39622: Add plotting for images. #59

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions python/lsst/summit/utils/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# This file is part of summit_utils.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import numpy as np

from lsst.afw.detection import FootprintSet
from lsst.afw.table import SourceCatalog
import lsst.geom as geom
from lsst.summit.utils import getQuantiles
import lsst.afw.image as afwImage

import matplotlib.pyplot as plt
import matplotlib.colors as colors

import astropy.visualization as vis


def plot(inputData,
figure=None,
centroids=None,
title=None,
showCompass=False,
stretch='linear',
percentile=99.,
cmap='gray_r',
compassLocation=250,
addLegend=True,
savePlotAs=None):

"""Make a plot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be significantly more descriptive, describing what the function does in some detail, e.g. that it supports many different types of images, shows compasses if the image has a wcs, etc.


Parameters
----------
inputData : `numpy.array`, `lsst.afw.image.Exposure`,
`lsst.afw.image.Image`, or `lsst.afw.image.MaskedImage`
The input data.
imageType : `str`, optional
If input data is an exposure, plot either 'image', or 'masked' image.
Defaults to 'image'.
madamow marked this conversation as resolved.
Show resolved Hide resolved
ax : `matplotlib.axes.Axes`, optional
The Matplotlib axis containing the image data plot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't contain it here, it will be used for plotting.

centroids : `list`
The centroids parameter represents a collection of centroid data.
It can be a combination of different types of data:

- List of tuples: Each tuple is a centroid with its (X,Y) coordinates.
- FootprintSet: lsst.afw.detection.FootprintSet object.
- SourceCatalog: A lsst.afw.table.SourceCatalog object.

You can provide any combination of these data types within the list.
The function will plot the centroid data accordingly.
title : `str`, optional
Title for the plot.
showCompass : `bool`, optional
Add compass to the plot? Defaults to False.
stretch : `str', optional
Changes mapping of colors for the image. Avaliable options:
ccs, log, power, asinh, linear, sqrt. Defaults to linear.
percentile : `float', optional
Parameter for astropy.visualization.PercentileInterval:
The fraction of pixels to keep. The same fraction of pixels
is eliminated from both ends. Here: defaults to 99.
Comment on lines +79 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing here is really being kept/eliminated here, though I see what you mean. It's just setting level in the image which will map to black/white (or some other colours depending on the colour map). I think more people are familiar enough with the concept of stretching to just say that this is used for the min/max of the stretch.

cmap : `str`, optional
matplotlib colormap. Defaults to 'gray_r'.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use full sentences to describe what this are.

compassLocation : `int`, optional
How far in from the bottom left of the image to display the compass.
By default, compass will be placed at pixel (X,Y) = (250,250).
addLegend : `bool', optional
Add legend to the plot.
savePlotAs : `str`, optional
The name of the file to save the plot as, including the file extension.
The extention must be supported by `matplotlib.pyplot`.
If None (default) plot will not be saved.

Returns
-------
fig : `matplotlib.figure.Figure`
The rendered image.
"""

if not figure:
figure = plt.figure(figsize=(10, 10))
ax = figure.add_subplot(111)

match inputData:
case np.ndarray():
imageData = inputData
case afwImage.MaskedImage():
imageData = inputData.image.array
case afwImage.Image():
imageData = inputData.array
case afwImage.Exposure():
imageData = inputData.image.array
case _:
raise TypeError("This function accepts numpy array, lsst.afw.image.Exposure components. "
"Got type(inputData)")

match stretch:
case 'ccs':
quantiles = getQuantiles(imageData, 256)
norm = colors.BoundaryNorm(quantiles, 256)
case 'asinh':
norm = vis.ImageNormalize(imageData,
interval=vis.PercentileInterval(percentile),
madamow marked this conversation as resolved.
Show resolved Hide resolved
stretch=vis.AsinhStretch(a=0.1))
case 'power':
norm = vis.ImageNormalize(imageData,
interval=vis.PercentileInterval(percentile),
stretch=vis.PowerStretch(a=2))
case 'log':
norm = vis.ImageNormalize(imageData,
interval=vis.PercentileInterval(percentile),
stretch=vis.LogStretch(a=1))
case 'linear':
norm = vis.ImageNormalize(imageData,
interval=vis.PercentileInterval(percentile),
stretch=vis.LinearStretch())
case 'sqrt':
norm = vis.ImageNormalize(imageData,
interval=vis.PercentileInterval(percentile),
stretch=vis.SqrtStretch())
case _:
raise ValueError(f"Invalid value for stretch : {stretch}. "
"Accepted options are: ccs, asinh, power, log, linear, sqrt.")

ax.imshow(imageData, cmap=cmap, origin='lower', norm=norm)

if showCompass:
color = 'r'
try:
wcs = inputData.getWcs()
except AttributeError:
wcs = None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this block, if we have compass set to True and we didn't manage to get a WCS there should be a logger warning to let the user know that they've failed. If the object was the wrong type then I think maybe that could pass silently, but if they passed in an exposure and failed to get a compass to plot then that shouldn't be silent.

if wcs:
anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30. * geom.arcseconds))

ax.arrow(compassLocation, compassLocation,
north[0]-compassLocation, north[1]-compassLocation,
head_width=1., head_length=1., color=color)
ax.arrow(compassLocation, compassLocation,
east[0]-compassLocation, east[1]-compassLocation,
head_width=1., head_length=1., color=color)
ax.text(north[0], north[1], 'N', color=color)
ax.text(east[0], east[1], 'E', color=color)

# Add centroids
if centroids:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this whole approach to dealing with multiple types of sources to plot might not be the most advisable. I think it might be more simple to have a footprints arg which accepts either a list of footprints or a footprintSet, a centroid arg which takes a list of centroids, and a sourceCat arg which accepts sourceCatalogs. I think that would make it easier for users, and make this code a little less contrived and delicate.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will rewrite this part of the code. I did not like it from the beginning.

cCycle = ['r', 'b', 'g', 'c', 'm', 'y', 'k']
# Create a dict with points from different sources
cenDict = {}
c_fs, c_sc, c_lst = 0, 0, 0 # index for color
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to avoid mixing camelCase and snake_case in the same code.

for cenSet in centroids:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite the API we discussed. If someone just passes a source catalog then this won't work, or likewise, just a footprintSet, etc - they would need to put them in a [] wrapper, which isn't very user friendly. You can use from lsst.utils.iteration import ensure_iterable and call that on the input if you like though, I think that will fix most of these cases.

match cenSet:
case FootprintSet():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we'd discussed this a few times via direct message, but this still doesn't accept a list of footprints as well as a FootprintSet. Please make it accept both.

fs = FootprintSet.getFootprints(cenSet)
xy = [_.getCentroid() for _ in fs]
key = 'footprintSet'+str(c_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a place for f-strings rather than calling str() on the counter.

cenDict[key] = {'data': xy}
cenDict[key]['m'] = '+'
cenDict[key]['c'] = cCycle[c_fs]
c_fs += 1
case SourceCatalog():
key = 'SourceCatalog'+str(c_sc)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise f-strings here too.

xy = list(zip(cenSet.getX(), cenSet.getY()))
cenDict[key] = {'data': xy}
cenDict[key]['m'] = 'x'
cenDict[key]['c'] = cCycle[c_sc]
c_sc += 1
case list():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is broken for taking a list of centroid tuples, if I pass centroids=[(100,100), (200,200), (300,300)] it will raise the TypeError, I'd need to pass centroids=[[(100,100), (200,200), (300,300)]] which is not very user-friendly.

key = 'tupleList'+str(c_lst)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here.

cenDict[key] = {'data': cenSet}
cenDict[key]['m'] = 's'
cenDict[key]['c'] = cCycle[c_lst]
c_lst += 1
case _:
raise TypeError("This function accepts a list of SourceCatalog, \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trailing \ is wrong: firstly, we don't use them - you don't need them when you're already inside parens (as you can see from your own code two lines later), but moreover, you've not closed the " so this is adding tons of spaces to the error message.

list of tuples, or FootprintSet. "
f"Got {type(cenSet)}: {cenSet}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As before, I don't think you want the : {cenSet} part, we don't need the centroids themselves printed in the error.


for cSet in cenDict:
ax.plot(*zip(*cenDict[cSet]['data']),
marker=cenDict[cSet]['m'],
markeredgecolor=cenDict[cSet]['c'],
markerfacecolor='None',
linestyle='None', label=cSet)

if addLegend:
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5)
if title:
ax.set_title(title)
if savePlotAs:
plt.savefig(savePlotAs)

return figure
97 changes: 97 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# This file is part of summit_utils.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest
import tempfile
import os

import lsst.utils.tests

from lsst.summit.utils.butlerUtils import makeDefaultLatissButler
from lsst.summit.utils.plotting import plot


class PlottingTestCase(lsst.utils.tests.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to add tests for the footprint plotting interface here too, as well as the raw centroid one: feel free to run the detection on the image using from lsst.summit.utils.utils import detectObjectsInExp with a low enough threshold to ensure that you get some sources. That way you can test the FootprintSet interface, as well as the (yet to be written) list of footprints one.


@classmethod
def setUpClass(cls):
try:
cls.butler = makeDefaultLatissButler()
except FileNotFoundError:
raise unittest.SkipTest("Skipping tests that require the LATISS butler repo.")
Comment on lines +34 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is done elsewhere in the tests, but we're moving away from this model. There are some test images in the afw_data package which could be used instead of getting things with a butler, or alternatively there's the new testing stuff that Hsin Fang has been writing. I don't mind which you switch to, but I think this type of testing shouldn't be used unless it's strictly necessary.


# Chosen to work on the TTS, summit and NCSA
cls.dataId = {'day_obs': 20200315, 'seq_num': 120, 'detector': 0}
cls.outputDir = tempfile.mkdtemp()

def test_plot(self):
"""Test that the the plot is made and saved
"""
exp = self.butler.get('raw', self.dataId)
centroids = [(567, 746), (576, 599), (678, 989)]

# Input is an exposure
outputFilename = os.path.join(self.outputDir, 'testPlotting_exp.jpg')
plot(exp,
centroids=[centroids],
showCompass=True,
savePlotAs=outputFilename)
self.assertTrue(os.path.isfile(outputFilename))
self.assertTrue(os.path.getsize(outputFilename) > 10000)

# Input is a numpy array
outputFilename = os.path.join(self.outputDir, 'testPlotting_nparr.jpg')
nparr = exp.image.array
plot(nparr,
showCompass=True,
centroids=[centroids],
savePlotAs=outputFilename)
self.assertTrue(os.path.isfile(outputFilename))
self.assertTrue(os.path.getsize(outputFilename) > 10000)

# Input is an image
outputFilename = os.path.join(self.outputDir, 'testPlotting_image.jpg')
im = exp.image
plot(im,
showCompass=True,
centroids=[centroids],
savePlotAs=outputFilename)
self.assertTrue(os.path.isfile(outputFilename))
self.assertTrue(os.path.getsize(outputFilename) > 10000)

# Input is a masked image
outputFilename = os.path.join(self.outputDir, 'testPloting_mask.jpg')
masked = exp.maskedImage
plot(masked,
showCompass=True,
centroids=[centroids],
savePlotAs=outputFilename)
self.assertTrue(os.path.isfile(outputFilename))
self.assertTrue(os.path.getsize(outputFilename) > 10000)


def setup_module(module):
lsst.utils.tests.init()


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()