Skip to content

Commit

Permalink
Merge pull request #2048 from kif/notebook-calibration
Browse files Browse the repository at this point in the history
improvements for the calibration from the notebook
  • Loading branch information
kif authored Jan 22, 2024
2 parents a85733f + c0504c7 commit 89d2436
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions src/pyFAI/gui/cli_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
__contact__ = "[email protected]"
__license__ = "MIT"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
__date__ = "20/04/2022"
__date__ = "19/01/2024"
__status__ = "production"

import os
Expand Down Expand Up @@ -561,14 +561,16 @@ def preprocess(self):
else:
self.peakPicker.massif.init_valley_size()

def extract_cpt(self, method="massif", pts_per_deg=1.0):
def extract_cpt(self, method="massif", pts_per_deg=1.0, max_rings=numpy.iinfo(int).max):
"""
Performs an automatic keypoint extraction:
Can be used in recalib or in calib after a first calibration has been performed.
:param method: method for keypoint extraction
:param pts_per_deg: number of control points per azimuthal degree (increase for better precision)
:param max_rings: extract at most max_rings
"""

logger.info("in extract_cpt with method %s", method)
assert self.ai
assert self.calibrant
Expand Down Expand Up @@ -605,13 +607,13 @@ def extract_cpt(self, method="massif", pts_per_deg=1.0):
self.max_rings = tth.size

ms = marchingsquares.MarchingSquaresMergeImpl(ttha, self.mask, use_minmax_cache=True)

for i in range(tth.size):
if rings >= self.max_rings:
if rings >= min(self.max_rings, max_rings):
break
mask = numpy.logical_and(ttha >= tth_min[i], ttha < tth_max[i])
if self.mask is not None:
mask = numpy.logical_and(mask, numpy.logical_not(self.mask))

size = mask.sum(dtype=int)
if (size > 0):
rings += 1
Expand Down Expand Up @@ -1350,31 +1352,39 @@ def reset_geometry(self, how="center", refine=False):
self.geoRef.set_rot3_max(math.pi)
self.geoRef.set_rot3(self.ai.rot3)

def initgeoRef(self):
def initgeoRef(self, defaults=None):
"""
Tries to initialise the GeometryRefinement (dist, poni, rot)
Returns a dictionary of key value pairs
:param: default parameters as a dict to be passed to constructor of GeometryRefinement
:return: initialized geometry refinement
"""
defaults = {"dist": 0.1, "poni1": 0.0, "poni2": 0.0,
"rot1": 0.0, "rot2": 0.0, "rot3": 0.0}
if defaults is None:
defaults = {"dist": 0.1, "poni1": 0.0, "poni2": 0.0,
"rot1": 0.0, "rot2": 0.0, "rot3": 0.0}
else:
defaults = defaults.copy()
if self.detector:
try:
p1, p2, _p3 = self.detector.calc_cartesian_positions()
defaults["poni1"] = p1.max() / 2.
defaults["poni2"] = p2.max() / 2.
except Exception as err:
logger.warning(err)
if not (defaults.get("poni1") or defaults.get("poni2")):
try:
p1, p2, _p3 = self.detector.calc_cartesian_positions()
defaults["poni1"] = p1.max() / 2.
defaults["poni2"] = p2.max() / 2.
except Exception as err:
logger.warning(err)
defaults["detector"] = self.detector
if self.ai:
for key in defaults.keys(): # not PARAMETERS which holds wavelength
val = getattr(self.ai, key, None)
if val is not None:
defaults[key] = val

georef = GeometryRefinement(self.data,
detector=self.detector,
wavelength=self.wavelength,
calibrant=self.calibrant,
**defaults)
if self.wavelength:
defaults["wavelength"] = self.wavelength
if self.calibrant:
defaults["calibrant"] = self.calibrant
if len(self.data):
defaults["data"] = self.data
georef = GeometryRefinement(**defaults)
return georef


Expand Down

0 comments on commit 89d2436

Please sign in to comment.