Skip to content

Commit

Permalink
Move LSST_PhotonPoolingImageBuilder into photon_pooling.py and make t…
Browse files Browse the repository at this point in the history
…he functions there its static methods.
  • Loading branch information
welucas2 committed Nov 7, 2024
1 parent c5c553a commit 5fc1eec
Show file tree
Hide file tree
Showing 3 changed files with 497 additions and 502 deletions.
1 change: 1 addition & 0 deletions imsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
from .sag import *
from .process_info import *
from .table_row import *
from .photon_pooling import *
235 changes: 65 additions & 170 deletions imsim/lsst_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,11 @@
import galsim
from galsim.config import RegisterImageType, GetAllParams, GetSky, AddNoise
from galsim.config.image_scattered import ScatteredImageBuilder
from galsim.errors import GalSimConfigValueError

from .sky_model import SkyGradient, CCD_Fringing
from .camera import get_camera
from .vignetting import Vignetting
from .photon_pooling import (
merge_photon_arrays,
calc_offset_adjustment,
offset_photon_arrays,
accumulate_photons,
build_stamps,
create_full_image,
load_checkpoint,
load_objects,
make_batches,
make_photon_batches,
partition_objects,
save_checkpoint,
set_config_image_pos,
stamp_bounds,
)


class LSST_ImageBuilderBase(ScatteredImageBuilder):
def setup(self, config, base, image_num, obj_num, ignore, logger):
Expand Down Expand Up @@ -196,6 +180,68 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger)
image += sky
AddNoise(base,image,current_var,logger)

@staticmethod
def create_full_image(config, base):
"""Create the GalSim image on which we will place the individual
object stamps once they are drawn.
Parameters:
config: The configuration dictionary for the image field.
base: The base configuration dictionary.
Returns:
full_image: The galsim.Image representing the full field.
"""
if galsim.__version_info__ < (2,5):
# GalSim 2.4 required a bit more work here.
from galsim.config.stamp import _ParseDType

full_xsize = base['image_xsize']
full_ysize = base['image_ysize']
wcs = base['wcs']

dtype = _ParseDType(config, base)

full_image = galsim.Image(full_xsize, full_ysize, dtype=dtype)
full_image.setOrigin(base['image_origin'])
full_image.wcs = wcs
full_image.setZero()
base['current_image'] = full_image
else:
# In GalSim 2.5+, the image is already built and available as 'current_image'
full_image = base['current_image']
return full_image

@staticmethod
def set_config_image_pos(config, base):
"""Determine the image position if necessary using information
from the base configuration.
Parameters:
config: The configuration dictionary for the image field.
base: The base configuration dictionary.
"""

if 'image_pos' in config and 'world_pos' in config:
raise galsim.config.GalSimConfigValueError(
"Both image_pos and world_pos specified for LSST_Image.",
(config['image_pos'], config['world_pos']))

if ('image_pos' not in config and 'world_pos' not in config and
not ('stamp' in base and
('image_pos' in base['stamp'] or 'world_pos' in base['stamp']))):
full_xsize = base['image_xsize']
full_ysize = base['image_ysize']
xmin = base['image_origin'].x
xmax = xmin + full_xsize-1
ymin = base['image_origin'].y
ymax = ymin + full_ysize-1
config['image_pos'] = {
'type' : 'XY' ,
'x' : { 'type' : 'Random' , 'min' : xmin , 'max' : xmax },
'y' : { 'type' : 'Random' , 'min' : ymin , 'max' : ymax }
}


class LSST_ImageBuilder(LSST_ImageBuilderBase):
"""This is mostly the same as the GalSim "Scattered" image type.
Expand All @@ -221,7 +267,7 @@ def buildImage(self, config, base, image_num, obj_num, logger):
Returns:
the final image and the current noise variance in the image as a tuple
"""
set_config_image_pos(config, base)
self.set_config_image_pos(config, base)

full_image = None
start_num = obj_num
Expand Down Expand Up @@ -263,7 +309,7 @@ def buildImage(self, config, base, image_num, obj_num, logger):
photon_ops[rubin_optics_index]['shift_photons'] = True

if full_image is None:
full_image = create_full_image(config, base)
full_image = self.create_full_image(config, base)

# Ensure 1 <= nbatch <= nobj_tot
nbatch = max(min(self.nbatch, nobj_tot), 1)
Expand Down Expand Up @@ -323,155 +369,4 @@ def buildImage(self, config, base, image_num, obj_num, logger):
return full_image, current_var


class LSST_PhotonPoolingImageBuilder(LSST_ImageBuilderBase):
"""Pools photon from all objects in `nbatch` batches.
Photons from faint objects only appear in one of the batches randomly.
"""

def setup(self, config, base, image_num, obj_num, ignore, logger):
# Check we're using the correct stamp type before calling base setup method.
if base['stamp']['type'] != 'LSST_Photons':
raise GalSimConfigValueError("Must use stamp.type = LSST_Photons with LSST_PhotonPoolingImage.", base['stamp']['type'])
return super().setup(config, base, image_num, obj_num, ignore, logger)


def buildImage(self, config, base, image_num, _obj_num, logger):
"""Build the Image.
In contrast to LSST_ImageBuilder, fluxes of all objects are precomputed
before rendering to determine how each object is rendered (FFT / photon shooting).
FFT objects will be handled before the photon shooting objects.
Batching is done over objects for FFT objects and over photons for
all photon shooting objects.
Parameters:
config: The configuration dict for the image field.
base: The base configuration dict.
image_num: The current image number.
obj_num: The first object number in the image.
logger: A logger object to log progress.
Returns:
the final image and the current noise variance in the image as a tuple
"""
set_config_image_pos(config, base)

# For cases where there is noise in individual stamps, we need to keep track of the
# stamp bounds and their current variances. When checkpointing, we don't need to
# save the pixel values for this, just the bounds and the current_var value of each.
all_stamps = []
all_vars = []
all_obj_nums = []
current_photon_batch_num = 0

full_image = None

if self.checkpoint is not None:
chk_name = "buildImage_photonpooling_" + self.det_name
full_image, all_vars, all_stamps, all_obj_nums, current_photon_batch_num = load_checkpoint(self.checkpoint, chk_name, base, logger)
remaining_obj_nums = sorted(frozenset(range(self.nobjects)) - frozenset(all_obj_nums))

if full_image is None:
full_image = create_full_image(config, base)

sensor = base.get('sensor', None)
rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon")
if sensor is not None:
sensor.updateRNG(rng)

# Create partitions each containing one of the three classes of object.
fft_objects, phot_objects, faint_objects = partition_objects(load_objects(remaining_obj_nums, config, base, logger))
logger.info("Found %d FFT objects, %d photon shooting objects and %d faint objects", len(fft_objects), len(phot_objects), len(faint_objects))
# Ensure 1 <= nbatch <= len(fft_objects)
nbatch = max(min(self.nbatch_fft, len(fft_objects)), 1)
if self.checkpoint is not None:
if not fft_objects:
logger.warning('All FFT objects already rendered for this image.')
else:
logger.warning("%d objects already rendered", len(all_obj_nums))

# Handle FFT objects first:
for batch_num, batch in enumerate(make_batches(fft_objects, nbatch), start=1):
if nbatch > 1:
logger.warning("Start FFT batch %d/%d with %d objects",
batch_num, nbatch, len(batch))
stamps, current_vars = build_stamps(base, logger, batch)
base['index_key'] = 'image_num'

for stamp_obj, stamp in zip(batch, stamps):
bounds = stamp_bounds(stamp, full_image.bounds)
if bounds is None:
continue
logger.debug('image %d: full bounds = %s', image_num, str(full_image.bounds))
logger.debug('image %d: stamp %d bounds = %s',
image_num, stamp_obj.index, str(stamp.bounds))
logger.debug('image %d: Overlap = %s', image_num, str(bounds))
full_image[bounds] += stamp[bounds]
all_obj_nums.append(stamp_obj.index)

# Note: in typical imsim usage, all current_vars will be 0. So this normally doens't
# add much to the checkpointing data.
nz_var = np.nonzero(current_vars)[0]
all_stamps.extend([stamps[k] for k in nz_var])
all_vars.extend([current_vars[k] for k in nz_var])

if self.checkpoint is not None:
save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, current_photon_batch_num)
logger.warning('File %d: Completed batch %d, and wrote '
'checkpoint data to %s',
base.get('file_num', 0), batch_num,
self.checkpoint.file_name)

# Ensure 1 <= nbatch <= len(phot_objects) and make batches.
nbatch = max(min(self.nbatch, len(phot_objects)), 1)
phot_batches = make_photon_batches(
config, base, logger, phot_objects, faint_objects, nbatch
)

if current_photon_batch_num > 0:
logger.warning(
"Photon batches [0, %d) / %d already rendered - skipping",
current_photon_batch_num,
nbatch,
)
phot_batches = phot_batches[current_photon_batch_num:]

base["image_pos"] = None
photon_ops_cfg = {"photon_ops": base.get("stamp", {}).get("photon_ops", [])}
photon_ops = galsim.config.BuildPhotonOps(photon_ops_cfg, 'photon_ops', base, logger)
offset_adjustment = calc_offset_adjustment(full_image.bounds)
local_wcs = base["wcs"].local(galsim.position._PositionD(0., 0.))
for batch_num, batch in enumerate(phot_batches, start=current_photon_batch_num):
if not batch:
continue
if nbatch > 1:
logger.warning("Starting photon batch %d/%d.",
batch_num+1, nbatch)

base['index_key'] = 'image_num'
stamps, current_vars = build_stamps(base, logger, batch)
offset_photon_arrays(stamps, offset_adjustment)
photons = merge_photon_arrays(stamps)
for op in photon_ops:
op.applyTo(photons, local_wcs, rng)
# Shift photon positions to be relative to full_image.center
photons.x -= full_image.center.x
photons.y -= full_image.center.y
accumulate_photons(photons, full_image, sensor, full_image.center)

# Note: in typical imsim usage, all current_vars will be 0. So this normally doesn't
# add much to the checkpointing data.
nz_var = np.nonzero(current_vars)[0]
all_vars.extend([current_vars[k] for k in nz_var])

if self.checkpoint is not None:
save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, batch_num+1)

# Bring the image so far up to a flat noise variance
current_var = galsim.config.FlattenNoiseVariance(
base, full_image, all_stamps, tuple(all_vars), logger)

return full_image, current_var

RegisterImageType('LSST_Image', LSST_ImageBuilder())
RegisterImageType('LSST_PhotonPoolingImage', LSST_PhotonPoolingImageBuilder())
Loading

0 comments on commit 5fc1eec

Please sign in to comment.