Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add serializable attributes to capture critical edge oversampling #250

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 66 additions & 30 deletions refl1d/probe/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import os
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Tuple, Union

from bumps.parameter import Parameter, to_dict
Expand Down Expand Up @@ -89,6 +89,13 @@ def make_probe(**kw):
return XrayProbe(**kw)


@dataclass
class OversampledRegion:
Q_start: float
Q_end: float
n: int # number of points


class BaseProbe:
intensity: Parameter
background: Parameter
Expand Down Expand Up @@ -696,6 +703,7 @@ class Probe(BaseProbe):
resolution: Literal["normal", "uniform"] = "uniform"
oversampling: Optional[int] = None
oversampling_seed: int = 1
oversampled_regions: List[OversampledRegion] = field(default_factory=list)
radiation: Literal["neutron", "xray"] = "xray"

polarized = False
Expand Down Expand Up @@ -736,6 +744,7 @@ def __init__(
resolution: Literal["normal", "uniform"] = "normal",
oversampling=None,
oversampling_seed=1,
oversampled_regions: Optional[List[OversampledRegion]] = None,
):
if T is None or L is None:
raise TypeError("T and L required")
Expand Down Expand Up @@ -763,8 +772,10 @@ def __init__(
self.name = name
self.filename = filename
self.resolution = resolution
if oversampling is not None:
self.oversample(oversampling, oversampling_seed)
self.oversampling = oversampling
self.oversampling_seed = oversampling_seed
self.oversampled_regions = oversampled_regions if oversampled_regions is not None else []
self._apply_oversamplings()

def _set_TLR(self, T, dT, L, dL, R, dR, dQ):
# if L is None:
Expand Down Expand Up @@ -975,9 +986,6 @@ def critical_edge(self, substrate=None, surface=None, n=51, delta=0.25):
*delta* is the relative uncertainty in the material density,
which defines the range of values which are calculated.

Note: :meth:`critical_edge` will remove the extra Q calculation
points introduced by :meth:`oversample`.

The $n$ points $Q_i$ are evenly distributed around the critical
edge in $Q_c \pm \delta Q_c$ by varying angle $\theta$ for a
fixed wavelength $< \lambda >$, the average of all wavelengths
Expand All @@ -993,19 +1001,11 @@ def critical_edge(self, substrate=None, surface=None, n=51, delta=0.25):
\lambda_i &= < \lambda > \\
\theta_i &= \sin^{-1}(Q_i \lambda_i / 4 \pi)

If $Q_c$ is imaginary, then $-|Q_c|$ is used instead, so this
routine can be used for reflectivity signals which scan from
back reflectivity to front reflectivity. For completeness,
the angle $\theta = 0$ is added as well.
"""
Q_c = self.Q_c(substrate, surface)
Q = np.linspace(Q_c * (1 - delta), Q_c * (1 + delta), n)
L = np.average(self.L)
T = QL2T(Q=Q, L=L)
T = np.hstack((self.T, T, 0))
L = np.hstack((self.L, [L] * (n + 1)))
# print Q
self._set_calc(T, L)
region = OversampledRegion(Q_start=Q_c * (1 - delta), Q_end=Q_c * (1 + delta), n=n)
self.oversampled_regions.append(region)
self._apply_oversamplings()

def oversample(self, n=20, seed=1):
"""
Expand All @@ -1031,19 +1031,32 @@ def oversample(self, n=20, seed=1):
bias from uniform Q steps. Depending on the problem, a value of
*n* between 20 and 100 should lead to stable values for the convolved
reflectivity.

Note: :meth:`oversample` will remove the extra Q calculation
points introduced by :meth:`critical_edge`.
"""

self.oversampling = n
self.oversampling_seed = seed
self._apply_oversamplings()

def _get_normal_oversampling_points(self, n, seed):
rng = numpy.random.RandomState(seed=seed)
T = rng.normal(self.T[:, None], self.dT[:, None], size=(len(self.dT), n - 1))
L = rng.normal(self.L[:, None], self.dL[:, None], size=(len(self.dL), n - 1))
T = np.hstack((self.T, T.flatten()))
L = np.hstack((self.L, L.flatten()))
return T.flatten(), L.flatten()

def _apply_oversamplings(self):
T_parts, L_parts = [self.T], [self.L]
if self.oversampling is not None:
T, L = self._get_normal_oversampling_points(self.oversampling, self.oversampling_seed)
T_parts.append(T)
L_parts.append(L)
for region in self.oversampled_regions:
Q = np.linspace(region.Q_start, region.Q_end, region.n)
avg_L = np.average(self.L)
T_parts.append(QL2T(Q=Q, L=avg_L))
L_parts.append(np.ones_like(Q) * avg_L)
T = np.hstack(T_parts)
L = np.hstack(L_parts)
self._set_calc(T, L)
self.oversampling = n
self.oversampling_seed = seed


class XrayProbe(Probe):
Expand Down Expand Up @@ -1352,6 +1365,9 @@ class QProbe(BaseProbe):
R: "NDArray"
dR: "NDArray"
resolution: Literal["normal", "uniform"]
oversampling: Optional[int]
oversampling_seed: int
oversampled_regions: List[OversampledRegion]

polarized = False

Expand All @@ -1375,6 +1391,9 @@ def __init__(
back_absorption=1,
back_reflectivity=False,
resolution: Literal["normal", "uniform"] = "normal",
oversampling: Optional[int] = None,
oversampling_seed: int = 1,
oversampled_regions: Optional[List[OversampledRegion]] = None,
):
if not name and filename:
name = os.path.splitext(os.path.basename(filename))[0]
Expand All @@ -1398,6 +1417,10 @@ def __init__(
self.name = name
self.filename = filename
self.resolution = resolution
self.oversampling = oversampling
self.oversampling_seed = oversampling_seed
self.oversampled_regions = oversampled_regions if oversampled_regions is not None else []
self._apply_oversamplings()

@property
def calc_Q(self):
Expand All @@ -1417,19 +1440,32 @@ def scattering_factors(self, material, density):

scattering_factors.__doc__ = Probe.scattering_factors.__doc__

def oversample(self, n=20, seed=1):
def _apply_oversamplings(self):
calc_Q_parts = [self.Q]
if self.oversampling is not None:
calc_Q_parts.append(self._get_normal_oversampling_points(self.oversampling, self.oversampling_seed))
for region in self.oversampled_regions:
calc_Q_parts.append(np.linspace(region.Q_start, region.Q_end, region.n))
calc_Q = np.hstack(calc_Q_parts)
self.calc_Qo = np.sort(calc_Q)

def _get_normal_oversampling_points(self, n, seed):
rng = numpy.random.RandomState(seed=seed)
extra = rng.normal(self.Q, self.dQ, size=(n - 1, len(self.Q)))
calc_Q = np.hstack((self.Q, extra.flatten()))
self.calc_Qo = np.sort(calc_Q)
return extra

def oversample(self, n=20, seed=1):
self.oversampling = n
self.oversampling_seed = seed
self._apply_oversamplings()

oversample.__doc__ = Probe.oversample.__doc__

def critical_edge(self, substrate=None, surface=None, n=51, delta=0.25):
Q_c = self.Q_c(substrate, surface)
extra = np.linspace(Q_c * (1 - delta), Q_c * (1 + delta), n)
calc_Q = np.hstack((self.Q, extra, 0))
self.calc_Qo = np.sort(calc_Q)
region = OversampledRegion(Q_start=Q_c * (1 - delta), Q_end=Q_c * (1 + delta), n=n)
self.oversampled_regions.append(region)
self._apply_oversamplings()

critical_edge.__doc__ = Probe.critical_edge.__doc__

Expand Down
Loading