Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculate centers #461

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
163 changes: 158 additions & 5 deletions imsim/stamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def setup(self, config, base, xsize, ysize, ignore, logger):

req = {}
opt = {'camera': str, 'diffraction_fft': dict,
'airmass': float, 'rawSeeing': float, 'band': str}
'airmass': float, 'rawSeeing': float, 'band': str,
'centroid': dict}
if self.vignetting:
req['det_name'] = str
else:
Expand Down Expand Up @@ -247,6 +248,9 @@ def setup(self, config, base, xsize, ysize, ignore, logger):
else:
world_pos = None

if 'centroid' in params and 'n_photons' not in params['centroid']:
raise ValueError('n_photons not found in centroid config')

return xsize, ysize, image_pos, world_pos

def _getGoodPhotImageSize(self, obj_list, keep_sb_level, pixel_scale):
Expand Down Expand Up @@ -758,9 +762,10 @@ def draw(self, prof, image, method, offset, config, base, logger):
else:
bp_for_drawImage = bandpass

# Put the psfs at the start of the photon_ops.
# Probably a little better to put them a bit later than the start in some cases
# (e.g. after TimeSampler, PupilAnnulusSampler), but leave that as a todo for now.
# Put the psfs at the start of the photon_ops. Probably a little
# better to put them a bit later than the start in some cases (e.g.
# after TimeSampler, PupilAnnulusSampler), but leave that as a todo
# for now.
photon_ops = psfs + photon_ops

if faint:
Expand All @@ -783,10 +788,158 @@ def draw(self, prof, image, method, offset, config, base, logger):
poisson_flux=False)
base['realized_flux'] = image.added_flux

if 'centroid' in config:
xvals, yvals = _get_photon_positions(
gal=gal,
rng=self.rng,
bp_for_drawImage=bp_for_drawImage,
image=image,
sensor=sensor,
photon_ops=photon_ops,
offset=offset,
config=config['centroid'],
)
cenres = _get_robust_centroids(xvals, yvals)

# these can be saved in the config using @xcentroid, @ycentroid
base['xcentroid'] = cenres['x']
base['xcentroid_err'] = cenres['xerr']
base['ycentroid'] = cenres['y']
base['ycentroid_err'] = cenres['yerr']

return image


def _get_photon_positions(
gal,
rng,
bp_for_drawImage,
image,

sensor,
photon_ops,
offset,
config,
):
"""
draw another image with fixed number of photons
and use to get centroid. We can't do this above
because with maxN the photons cannot be saved

100_000 photons should give centroid to a few miliarcsec
"""

timage = image.copy()
gal.drawImage(
bp_for_drawImage,
image=timage,

method='phot',
n_photons=config['n_photons'],
sensor=sensor,
photon_ops=photon_ops,
poisson_flux=False,
save_photons=True,

offset=offset,
rng=rng,
)

photx = timage.photons.x
photy = timage.photons.y

flux = timage.photons.flux
wvig, = np.where(flux == 0)

logic = np.isnan(photx) | np.isnan(photy)
wnan, = np.where(logic)

# flux == 0 means it was vignetted
wgood, = np.where(
np.isfinite(photx)
& np.isfinite(photy)
& (timage.photons.flux > 0)
)

# location of true center, actually in big image
imcen = image.true_center
xvals = imcen.x + photx[wgood]
yvals = imcen.y + photy[wgood]

return xvals, yvals


def _get_robust_centroids(xvals, yvals):
xcen, _, xerr = sigma_clip(xvals)
ycen, _, yerr = sigma_clip(yvals)

return {
'x': xcen,
'xerr': xerr,
'y': ycen,
'yerr': yerr,
}


def sigma_clip(arrin, niter=4, nsig=4):
"""
Calculate the mean, sigma, error of an array with sigma clipping.

parameters
----------
arr: array or sequence
A numpy array or sequence
niter: int, optional
number of iterations, defaults to 4
nsig: float, optional
number of sigma, defaults to 4

returns
-------
mean, stdev, err
"""

arr = np.array(arrin, ndmin=1, copy=False)

if len(arr.shape) > 1:
raise ValueError(
'only 1-dimensional arrays suppored, got {arr.shape}'
)

indices = np.arange(arr.size)
nold = arr.size

mn, sig, err = _get_sigma_clip_stats(arr)

for i in range(1, niter + 1):

w, = np.where((np.abs(arr[indices] - mn)) < nsig * sig)

if w.size == 0:
# everything clipped, nothing to do but report latest
# statistics
break

if w.size == nold:
break

indices = indices[w]
nold = w.size

mn, sig, err = _get_sigma_clip_stats(arr[indices])

return mn, sig, err


def _get_sigma_clip_stats(arr):
mn = arr.mean()
sig = arr.std()
err = sig / np.sqrt(arr.size)
return mn, sig, err


# Pick the right function to be _fix_seds.
if galsim.__version_info__ < (2,5):
if galsim.__version_info__ < (2, 5):
LSST_SiliconBuilder._fix_seds = LSST_SiliconBuilder._fix_seds_24
else:
LSST_SiliconBuilder._fix_seds = LSST_SiliconBuilder._fix_seds_25
Expand Down
44 changes: 43 additions & 1 deletion tests/test_stamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,49 @@ def new_toplevel_only(self, object_types):
np.testing.assert_allclose(realized_X, predict, rtol=rtol)


def test_centroid_smoke() -> None:
"""
Test that centroids are calculated.

We are not checking accuracy here
"""
lsst_silicon = create_test_lsst_silicon(faint=False)
image = galsim.Image(ncol=256, nrow=256)
offset = galsim.PositionD(0, 0)
config = create_test_config()
config['stamp']['centroid'] = {
'n_photons': 1000,
}

# order matters here
config['stamp']['photon_ops'] = [
{'type': 'TimeSampler', 't0': 0.0, 'exptime': 30.0},
{'type': 'PupilAnnulusSampler', 'R_outer': 4.18, 'R_inner': 2.55},
] + config['stamp']['photon_ops']

prof = galsim.Gaussian(sigma=2) * galsim.SED('vega.txt', 'nm', 'flambda')
method = 'phot'
logger = galsim.config.LoggerWrapper(None)

image.added_flux = lsst_silicon.phot_flux
lsst_silicon.draw(
prof,
image,
method,
offset,
config=config["stamp"],
base=config,
logger=logger,
)
for c in ('x', 'y'):
for n in ('', '_err'):
name = f'{c}centroid{n}'
assert name in config


if __name__ == "__main__":
testfns = [v for k, v in vars().items() if k[:5] == 'test_' and callable(v)]
testfns = [
v for k, v in vars().items() if k[:5] == 'test_' and callable(v)
]
for testfn in testfns:
testfn()
Loading