-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mid-level driver for measurement algorithms
- Loading branch information
1 parent
32e5bea
commit d0cbde4
Showing
2 changed files
with
380 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,379 @@ | ||
# This file is part of meas_algorithms. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (https://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
|
||
__all__ = ["MeasurementDriverConfig", "MeasurementDriverTask"] | ||
|
||
import logging | ||
|
||
import lsst.afw.image as afwImage | ||
import lsst.afw.table as afwTable | ||
import lsst.meas.algorithms as measAlgorithms | ||
import lsst.meas.base as measBase | ||
import lsst.meas.deblender as measDeblender | ||
import lsst.meas.extensions.scarlet as scarlet | ||
import lsst.pex.config as pexConfig | ||
import lsst.pipe.base as pipeBase | ||
import numpy as np | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
class MeasurementDriverConfig(pexConfig.Config): | ||
"""Configuration parameters for `MeasurementDriverTask`.""" | ||
|
||
# To generate catalog ids consistently across subtasks. | ||
id_generator = measBase.DetectorVisitIdGeneratorConfig.make_field() | ||
|
||
detection = pexConfig.ConfigurableField( | ||
target=measAlgorithms.SourceDetectionTask, | ||
doc="Task to detect sources to return in the output catalog.", | ||
) | ||
|
||
deblender = pexConfig.ChoiceField[str]( | ||
doc="The deblender to use.", | ||
default="meas_deblender", | ||
allowed={"meas_deblender": "Deblend using meas_deblender", "scarlet": "Deblend using scarlet"}, | ||
) | ||
|
||
deblend = pexConfig.ConfigurableField( | ||
target=measDeblender.SourceDeblendTask, doc="Split blended sources into their components." | ||
) | ||
|
||
measurement = pexConfig.ConfigurableField( | ||
target=measBase.SingleFrameMeasurementTask, | ||
doc="Task to measure sources to return in the output catalog.", | ||
) | ||
|
||
def __setattr__(self, key, value): | ||
"""Intercept attribute setting to trigger setDefaults when relevant | ||
fields change. | ||
""" | ||
super().__setattr__(key, value) | ||
|
||
# This is to ensure the deblend target is set correctly whenever the | ||
# deblender is changed. This is required because `setDefaults` is not | ||
# automatically invoked during reconfiguration. | ||
if key == "deblender": | ||
self.setDefaults() | ||
|
||
def validate(self): | ||
super().validate() | ||
|
||
# Ensure the deblend target aligns with the selected deblender. | ||
if self.deblender == "scarlet": | ||
assert self.deblend.target == scarlet.ScarletDeblendTask | ||
elif self.deblender == "meas_deblender": | ||
assert self.deblend.target == measDeblender.SourceDeblendTask | ||
elif self.deblender is not None: | ||
raise ValueError(f"Invalid deblender value: {self.deblender}") | ||
|
||
def setDefaults(self): | ||
super().setDefaults() | ||
if self.deblender == "scarlet": | ||
self.deblend.retarget(scarlet.ScarletDeblendTask) | ||
elif self.deblender == "meas_deblender": | ||
self.deblend.retarget(measDeblender.SourceDeblendTask) | ||
|
||
|
||
class MeasurementDriverTask(pipeBase.Task): | ||
"""A mid-level driver for running detection, deblending (optional), and | ||
measurement algorithms in one go. | ||
This driver simplifies the process of applying a small set of measurement | ||
algorithms to images by abstracting away schema and table boilerplate. It | ||
is particularly suited for simple use cases, such as processing images | ||
without neighbor-noise-replacement or extensive configuration. | ||
Designed to streamline the measurement framework, this class integrates | ||
detection, deblending (if enabled), and measurement into a single workflow. | ||
Parameters | ||
---------- | ||
schema : `~lsst.afw.table.Schema` | ||
Schema used to create the output `~lsst.afw.table.SourceCatalog`, | ||
modified in place with fields that will be written by this task. | ||
**kwargs : `dict` | ||
Additional kwargs to pass to lsst.pipe.base.Task.__init__() | ||
Examples | ||
-------- | ||
Here is an example of how to use this class to run detection, deblending, | ||
and measurement on a given exposure: | ||
>>> from lsst.meas.algorithms import MeasurementDriverTask | ||
>>> config = MeasurementDriverTask().ConfigClass() | ||
>>> config.detection.thresholdValue = 5.5 | ||
>>> config.deblender = "meas_deblender" | ||
>>> config.deblend.tinyFootprintSize = 3 | ||
>>> config.measurement.plugins.names |= [ | ||
... "base_SdssCentroid", | ||
... "base_SdssShape", | ||
... "ext_shapeHSM_HsmSourceMoments", | ||
... ] | ||
>>> config.measurement.slots.psfFlux = None | ||
>>> config.measurement.doReplaceWithNoise = False | ||
>>> exposure = butler.get("deepCoadd", dataId=...) | ||
>>> driver = MeasurementDriverTask(config=config) | ||
>>> catalog = driver.run(exposure) | ||
>>> catalog.writeFits("meas_catalog.fits") | ||
""" | ||
|
||
ConfigClass = MeasurementDriverConfig | ||
_DefaultName = "measurementDriver" | ||
|
||
def __init__(self, schema=None, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
# Create a minimal schema that will be extended by tasks, if not given. | ||
if schema is None: | ||
self.schema = afwTable.SourceTable.makeMinimalSchema() | ||
else: | ||
self.schema = schema | ||
|
||
# Add coordinate error fields to the schema (this is to avoid errors | ||
# such as: "Field with name 'coord_raErr' not found with type 'F'"). | ||
afwTable.CoordKey.addErrorFields(self.schema) | ||
|
||
self.subtasks = ["detection", "deblend", "measurement"] | ||
|
||
def make_subtasks(self): | ||
"""Create subtasks based on the current configuration.""" | ||
for name in self.subtasks: | ||
self.makeSubtask(name, schema=self.schema) | ||
|
||
def run( | ||
self, | ||
image, | ||
bands=None, | ||
band=None, | ||
mask=None, | ||
variance=None, | ||
psf=None, | ||
wcs=None, | ||
photo_calib=None, | ||
id_generator=None, | ||
): | ||
"""Run detection, optional deblending, and measurement on a given | ||
image. | ||
Parameters | ||
---------- | ||
image: `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` or | ||
`~lsst.afw.image.Image` or `np.ndarray` or | ||
`~lsst.afw.image.MultibandExposure` or | ||
`list` of `~lsst.afw.image.Exposure` | ||
The image on which to detect, deblend and measure sources. If | ||
provided as a multiband exposure, or a list of `Exposure` objects, | ||
it can only be used with the 'scarlet' deblender. When using a list | ||
of `Exposure` objects, the ``bands`` parameter must also be | ||
provided for scarlet deblending. | ||
bands: `str` or `list` of `str`, optional | ||
The band(s) of the image. Required if ``image`` is provided as a | ||
list of `Exposure` objects to use in scarlet deblending. Example: | ||
["g", "r", "i", "z", "y"] or "grizy". | ||
band: `str`, optional | ||
The target band of the image to use for detection and measurement. | ||
Required for scarlet deblending when ``image`` is provided as a | ||
`MultibandExposure`, or a list of `Exposure` objects. | ||
mask: `~lsst.afw.image.Mask`, optional | ||
The mask to use for detection. Will be ignored if ``image`` is | ||
provided as a `MaskedImage`, a `MultibandExposure`, an `Exposure` | ||
, or a list of `Exposure` objects. | ||
variance: `~lsst.afw.image.Image`, optional | ||
The variance image to use for measurement. Will be ignored if | ||
``image`` is provided as a `MaskedImage`, a `MultibandExposure`, an | ||
`Exposure`, or a list of `Exposure` objects. | ||
psf: `~lsst.afw.detection.Psf`, optional | ||
The PSF model to use for measurement. Will be ignored if ``image`` | ||
is provided as a `MultibandExposure`, an `Exposure`, or a list of | ||
`Exposure` objects. | ||
wcs: `~lsst.afw.image.Wcs`, optional | ||
The World Coordinate System (WCS) model to use for measurement. | ||
Will be ignored if ``image`` is provided as a `MultibandExposure`, | ||
an `Exposure`, or a list of `Exposure` objects. | ||
photo_calib : `~lsst.afw.image.PhotoCalib`, optional | ||
Photometric calibration model to use for measurement. Will be | ||
ignored if ``image`` is provided as a `MultibandExposure`, an | ||
`Exposure`, or a list of `Exposure` objects. | ||
id_generator : `~lsst.meas.base.IdGenerator`, optional | ||
Object that generates source IDs and provides random seeds. | ||
Returns | ||
------- | ||
catalog : `~lsst.afw.table.SourceCatalog` | ||
The source catalog with all requested measurements. | ||
""" | ||
|
||
# Only make the `deblend` subtask if it is enabled. | ||
if self.config.deblender is None: | ||
self.subtasks.remove("deblend") | ||
|
||
# Validate the configuration before running the task. | ||
self.config.validate() | ||
|
||
# This guarantees the `run` method picks up the current subtask config. | ||
self.make_subtasks() | ||
# N.B. subtasks must be created here to handle reconfigurations, such | ||
# as retargeting the `deblend` subtask, because the `makeSubtask` | ||
# method locks in its config just before creating the subtask. If the | ||
# subtask was already made in __init__ using the initial config, it | ||
# cannot be retargeted now because retargeting happens to the config. | ||
|
||
if id_generator is None: | ||
id_generator = measBase.IdGenerator() | ||
|
||
if isinstance(image, afwImage.MultibandExposure) or isinstance(image, list): | ||
if self.config.deblender != "scarlet": | ||
self.log.debug( | ||
"Supplied a multiband exposure, or a list of exposures, while the deblender is set to " | ||
f"'{self.config.deblender}'. A single exposure corresponding to target `band` will be " | ||
"used." | ||
) | ||
if band is None: | ||
raise ValueError( | ||
"The target `band` must be provided when using multiband exposures or a list of " | ||
"exposures." | ||
) | ||
if isinstance(image, list): | ||
if not all(isinstance(im, afwImage.Exposure) for im in image): | ||
raise ValueError("All elements in the `image` list must be `Exposure` objects.") | ||
if bands is None: | ||
raise ValueError( | ||
"The `bands` parameter must be provided if `image` is a list of `Exposure` objects." | ||
) | ||
if not isinstance(bands, (str, list)) or ( | ||
isinstance(bands, list) and not all(isinstance(b, str) for b in bands) | ||
): | ||
raise TypeError( | ||
"The `bands` parameter must be a string or a list of strings if provided." | ||
) | ||
if len(bands) != len(image): | ||
raise ValueError( | ||
"The number of bands must match the number of `Exposure` objects in the list." | ||
) | ||
else: | ||
if band is None: | ||
band = "N/A" # Just a placeholder for single-band deblending | ||
else: | ||
self.log.warn("The target `band` is not required when the input image is not multiband.") | ||
if bands is not None: | ||
self.log.warn( | ||
"The `bands` parameter will be ignored because the input image is not multiband." | ||
) | ||
|
||
if self.config.deblender == "scarlet": | ||
if not isinstance(image, (afwImage.MultibandExposure, list, afwImage.Exposure)): | ||
raise ValueError( | ||
"The `image` parameter must be a `MultibandExposure`, a list of `Exposure` " | ||
"objects, or a single `Exposure` when the deblender is set to 'scarlet'." | ||
) | ||
if isinstance(image, afwImage.Exposure): | ||
# N.B. scarlet is designed to leverage multiband information to | ||
# differentiate overlapping sources based on their spectral and | ||
# spatial profiles. However, it can also run on a single band | ||
# and still give better results than 'meas_deblender'. | ||
self.log.debug( | ||
"Supplied a single-band exposure, while the deblender is set to 'scarlet'." | ||
"Make sure it was intended." | ||
) | ||
|
||
# Start with some image conversions if needed. | ||
if isinstance(image, np.ndarray): | ||
image = afwImage.makeImageFromArray(image) | ||
if isinstance(mask, np.ndarray): | ||
mask = afwImage.makeMaskFromArray(mask) | ||
if isinstance(variance, np.ndarray): | ||
variance = afwImage.makeImageFromArray(variance) | ||
if isinstance(image, afwImage.Image): | ||
image = afwImage.makeMaskedImage(image, mask, variance) | ||
|
||
# Avoid type checker errors by being explicit from here on. | ||
exposure: afwImage.Exposure | ||
|
||
# Make sure we have an `Exposure` object to work with (potentially | ||
# along with a `MultiBandExposure` for scarlet deblending). | ||
if isinstance(image, afwImage.Exposure): | ||
exposure = image | ||
elif isinstance(image, afwImage.MaskedImage): | ||
exposure = afwImage.makeExposure(image, wcs) | ||
if psf is not None: | ||
exposure.setPsf(psf) | ||
if photo_calib is not None: | ||
exposure.setPhotoCalib(photo_calib) | ||
elif isinstance(image, list): | ||
# Construct a multiband exposure for scarlet deblending. | ||
exposures = afwImage.MultibandExposure.fromExposures(bands, image) | ||
# Select the exposure of the desired band, which will be used for | ||
# detection and measurement. | ||
exposure = exposures[band] | ||
elif isinstance(image, afwImage.MultibandExposure): | ||
exposures = image | ||
exposure = exposures[band] | ||
else: | ||
raise TypeError(f"Unsupported image type: {type(image)}") | ||
|
||
# Create a source table into which detections will be placed. | ||
table = afwTable.SourceTable.make(self.schema, id_generator.make_table_id_factory()) | ||
|
||
# Detect sources and get a source catalog. | ||
self.log.info(f"Running detection on a {exposure.width}x{exposure.height} pixel image") | ||
detections = self.detection.run(table, exposure) | ||
catalog = detections.sources | ||
|
||
# Deblend sources into their components and update the catalog. | ||
if self.config.deblender is None: | ||
self.log.info("Deblending is disabled; skipping deblending") | ||
else: | ||
self.log.info( | ||
f"Running deblending via '{self.config.deblender}' on {len(catalog)} detection footprints" | ||
) | ||
if self.config.deblender == "meas_deblender": | ||
self.deblend.run(exposure=exposure, sources=catalog) | ||
elif self.config.deblender == "scarlet": | ||
if not isinstance(image, (afwImage.MultibandExposure, list)): | ||
# We need to have a multiband exposure to satisfy scarlet | ||
# function's signature, even when only using a single band. | ||
exposures = afwImage.MultibandExposure.fromExposures([band], [exposure]) | ||
catalog, model_data = self.deblend.run(mExposure=exposures, mergedSources=catalog) | ||
# The footprints need to be updated for the subsequent | ||
# measurement to work. | ||
scarlet.io.updateCatalogFootprints( | ||
modelData=model_data, | ||
catalog=catalog, | ||
band=band, | ||
imageForRedistribution=exposure, | ||
removeScarletData=True, | ||
updateFluxColumns=True, | ||
) | ||
|
||
# The deblender may not produce a contiguous catalog; ensure contiguity | ||
# for the subsequent task. | ||
if not catalog.isContiguous(): | ||
self.log.info("Catalog is not contiguous; making it contiguous") | ||
catalog = catalog.copy(deep=True) | ||
|
||
# Measure requested quantities on sources. | ||
self.measurement.run(catalog, exposure) | ||
self.log.info( | ||
f"Measured {len(catalog)} sources and stored them in the output " | ||
f"catalog containing {catalog.schema.getFieldCount()} fields" | ||
) | ||
|
||
return catalog |