Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diffractive detector & Multi frequency phasor #19

Merged
merged 2 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions src/fdtdx/objects/detectors/diffractive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import Literal, Sequence, Tuple

import jax
import jax.numpy as jnp
import pytreeclass as tc

from fdtdx.core.physics import constants
from fdtdx.objects.detectors.detector import Detector, DetectorState


@tc.autoinit
class DiffractiveDetector(Detector):
"""Detector for computing Fourier transforms of fields at specific frequencies and diffraction orders.

This detector computes field amplitudes for specific diffraction orders and frequencies through
a specified plane in the simulation volume. It can measure diffraction in either positive or negative
direction along the propagation axis.

Attributes:
frequencies: List of frequencies to analyze (in Hz)
orders: Tuple of (nx, ny) pairs specifying diffraction orders to compute
direction: Direction of diffraction analysis ("+" or "-") along propagation axis
"""

frequencies: Sequence[float] = 0.0
orders: Sequence[Tuple[int, int]] = ((0, 0),)
direction: Literal["+", "-"] = tc.field( # type: ignore
init=True,
kind="KW_ONLY",
on_getattr=[tc.unfreeze],
on_setattr=[tc.freeze],
)
dtype: jnp.dtype = tc.field(
default=jnp.complex64,
kind="KW_ONLY",
)

def __post_init__(self):
if self.dtype not in [jnp.complex64, jnp.complex128]:
raise Exception(f"Invalid dtype in DiffractiveDetector: {self.dtype}")

@property
def propagation_axis(self) -> int:
"""Determines the axis along which diffraction is measured.

The propagation axis is identified as the dimension with size 1 in the
detector's grid shape, representing a plane perpendicular to the diffraction
measurement direction.

Returns:
int: Index of the propagation axis (0 for x, 1 for y, 2 for z)

Raises:
Exception: If detector shape does not have exactly one dimension of size 1
"""
if sum([a == 1 for a in self.grid_shape]) != 1:
raise Exception(f"Invalid diffractive detector shape: {self.grid_shape}")
return self.grid_shape.index(1)

def _validate_orders(self, wavelength: float) -> None:
"""Validate that requested diffraction orders are physically realizable.

Args:
wavelength: Wavelength of the light in meters

Raises:
Exception: If any requested order is not physically realizable
"""
if self._Nx is None:
raise Exception("Order info not yet computed. Run update first.")

# Maximum possible orders based on grid
max_nx = self._Nx // 2
max_ny = self._Ny // 2

# Check Nyquist limits for all orders at once
nx_valid = jnp.all(jnp.abs(jnp.array([o[0] for o in self.orders])) <= max_nx)
ny_valid = jnp.all(jnp.abs(jnp.array([o[1] for o in self.orders])) <= max_ny)

if not (nx_valid and ny_valid):
raise Exception(f"Some orders exceed Nyquist limit for grid size ({self._Nx}, {self._Ny})")

# Check physical realizability for all orders at once
k0 = 2 * jnp.pi / wavelength
kt_squared = self._kx_normalized**2 + self._ky_normalized**2

if jnp.any(kt_squared > k0**2):
raise Exception(f"Some orders are evanescent at wavelength {wavelength*1e9:.1f}nm")

def _shape_dtype_single_time_step(self) -> dict[str, jax.ShapeDtypeStruct]:
"""Define shape and dtype for a single time step of diffractive data.

Returns:
dict: Dictionary mapping data keys to ShapeDtypeStruct containing shape and
dtype information for each frequency and order combination.
"""
num_freqs = len(self.frequencies)
num_orders = len(self.orders)

shape = (num_freqs, num_orders)

# Ensure we're using a complex dtype
field_dtype = jnp.complex128 if self.dtype == jnp.float64 else jnp.complex64
return {"diffractive": jax.ShapeDtypeStruct(shape=shape, dtype=field_dtype)}

def _num_latent_time_steps(self) -> int:
"""Get number of time steps needed for latent computation.

Returns:
int: Always returns 1 for diffractive detector since only current state is needed.
"""
return 1

def update(
self,
time_step: jax.Array,
E: jax.Array,
H: jax.Array,
state: DetectorState,
inv_permittivity: jax.Array,
inv_permeability: jax.Array,
) -> DetectorState:
"""Update the diffractive detector state with current field values."""
del inv_permittivity, inv_permeability

# Get grid dimensions for the plane perpendicular to propagation axis
prop_axis = self.propagation_axis
plane_dims = [i for i in range(3) if i != prop_axis]
Nx, Ny = [self.grid_shape[i] for i in plane_dims]

# Get current field values at the specified plane
cur_E = E[:, *self.grid_slice] # Shape: (3, nx, ny, 1)
cur_H = H[:, *self.grid_slice] # Shape: (3, nx, ny, 1)

# Remove the normal axis dimension since it should be 1
cur_E = jnp.squeeze(cur_E, axis=prop_axis + 1) # Shape: (3, nx, ny)
cur_H = jnp.squeeze(cur_H, axis=prop_axis + 1) # Shape: (3, nx, ny)

# Compute FFT of each field component
E_k = jnp.fft.fft2(cur_E, axes=tuple(d + 1 for d in plane_dims)) # FFT in spatial dimensions
H_k = jnp.fft.fft2(cur_H, axes=tuple(d + 1 for d in plane_dims))

# Convert orders to array for vectorization
orders = jnp.array(self.orders) # Shape: (num_orders, 2)

# Compute FFT indices for all orders
kx_indices = jnp.where(orders[:, 0] >= 0, orders[:, 0], Nx + orders[:, 0])
ky_indices = jnp.where(orders[:, 1] >= 0, orders[:, 1], Ny + orders[:, 1])

# Compute wavevectors
dx = dy = self._config.resolution
kx = 2 * jnp.pi * jnp.fft.fftfreq(Nx, dx)
ky = 2 * jnp.pi * jnp.fft.fftfreq(Ny, dy)
k0 = 2 * jnp.pi * self.frequencies[0] / constants.c # Use first frequency for now

# For each requested order, compute the diffracted power
order_amplitudes = []
for kx_idx, ky_idx in zip(kx_indices, ky_indices):
# Get the field components for this k-point
E_order = E_k[:, kx_idx, ky_idx]
H_order = H_k[:, kx_idx, ky_idx]

# Compute kz for propagating waves
kz = jnp.sqrt(k0**2 - kx[kx_idx] ** 2 - ky[ky_idx] ** 2 + 0j)
k_vec = jnp.array([kx[kx_idx], ky[ky_idx], kz])

# Project fields to be transverse to k
E_t = E_order - jnp.dot(E_order, k_vec) * k_vec / jnp.dot(k_vec, k_vec)
H_t = H_order - jnp.dot(H_order, k_vec) * k_vec / jnp.dot(k_vec, k_vec)

# Compute power in this order
P_order = jnp.abs(jnp.cross(E_t, jnp.conj(H_t)).sum())
if self.direction == "-":
P_order = -P_order
order_amplitudes.append(P_order)

order_amplitudes = jnp.array(order_amplitudes)

# Time domain analysis - vectorized for all frequencies
t = time_step * self._config.time_step_duration
angular_frequencies = 2 * jnp.pi * jnp.array(self.frequencies)
phase_angles = angular_frequencies[:, None] * t # Shape: (num_freqs, 1)
phasors = jnp.exp(-1j * phase_angles) # Shape: (num_freqs, 1)

# Compute all frequency components for all orders at once
order_amplitudes = order_amplitudes[None, :] # Shape: (1, num_orders)
new_values = order_amplitudes * phasors # Shape: (num_freqs, num_orders)

# Update state
arr_idx = self._time_step_to_arr_idx[time_step]
new_state = state.copy()
new_state["diffractive"] = new_state["diffractive"].at[arr_idx].set(new_values)

return new_state
61 changes: 24 additions & 37 deletions src/fdtdx/objects/detectors/phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import extended_autoinit, frozen_field
from fdtdx.core.physics import constants
from fdtdx.objects.detectors.detector import Detector, DetectorState


@extended_autoinit
class PhasorDetector(Detector):
"""Detector for measuring phasor components of electromagnetic fields.

This detector computes complex phasor representations of the field components,
enabling frequency-domain analysis of the electromagnetic fields.
This detector computes complex phasor representations of the field components at specified
frequencies, enabling frequency-domain analysis of the electromagnetic fields.

Attributes:
frequencies: Sequence of frequencies to analyze (in Hz)
as_slices: If True, returns results as slices rather than full volume.
reduce_volume: If True, reduces the volume of recorded data.
wavelength: Wavelength of the phasor analysis in meters. Either this or
period_length must be specified.
components: Sequence of field components to measure. Can include any of:
"Ex", "Ey", "Ez", "Hx", "Hy", "Hz".
"""

frequencies: Sequence[float] = (None,)
as_slices: bool = False
reduce_volume: bool = False
wavelength: float | None = None
components: Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]] = frozen_field(
default=("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"),
)
Expand All @@ -41,24 +39,9 @@ def __post_init__(
if self.dtype not in [jnp.complex64, jnp.complex128]:
raise Exception(f"Invalid dtype in PhasorDetector: {self.dtype}")

@property
def frequency(self) -> float:
"""Calculate the frequency for phasor analysis.

Returns:
float: The frequency in Hz calculated from either wavelength or period_length.

Raises:
Exception: If neither wavelength nor period_length is specified.
"""
if self.period_length is None and self.wavelength is None:
raise Exception("Specify either wavelength or period_length for PhasorDetector")
p = self.period_length
if p is None:
if self.wavelength is None:
raise Exception("this should never happen")
p = self.wavelength / constants.c
return 1 / p
# Precompute angular frequencies for vectorization
self._angular_frequencies = 2 * jnp.pi * jnp.array(self.frequencies)
self._scale = self._config.time_step_duration / jnp.sqrt(2 * jnp.pi)

def _num_latent_time_steps(self) -> int:
"""Get number of time steps needed for latent computation.
Expand All @@ -75,12 +58,13 @@ def _shape_dtype_single_time_step(

Returns:
dict: Dictionary with 'phasor' key mapping to a ShapeDtypeStruct containing:
- shape: (num_components, *grid_shape)
- shape: (num_frequencies, num_components, *grid_shape)
- dtype: Complex64 or Complex128 depending on detector's dtype
"""
field_dtype = jnp.complex128 if self.dtype == jnp.float64 else jnp.complex64
num_components = len(self.components)
phasor_shape = (num_components, *self.grid_shape)
num_frequencies = len(self.frequencies)
phasor_shape = (num_frequencies, num_components, *self.grid_shape)
return {"phasor": jax.ShapeDtypeStruct(shape=phasor_shape, dtype=field_dtype)}

def update(
Expand All @@ -94,8 +78,8 @@ def update(
) -> DetectorState:
"""Update the phasor state with current field values.

Computes the phasor representation by multiplying field components with a complex
exponential at the detector's frequency.
Computes the phasor representation by multiplying field components with complex
exponentials at each of the detector's frequencies.

Args:
time_step: Current simulation time step
Expand All @@ -109,12 +93,7 @@ def update(
DetectorState: Updated state containing new phasor values
"""
del inv_permeability, inv_permittivity
delta_t = self._config.time_step_duration
scale = delta_t / jnp.sqrt(2 * jnp.pi)
angular_frequency = 2 * jnp.pi * self.frequency
time_passed = time_step * delta_t
phase_angle = angular_frequency * time_passed
phasor = jnp.exp(1j * phase_angle)
time_passed = time_step * self._config.time_step_duration

E, H = E[:, *self.grid_slice], H[:, *self.grid_slice]
fields = []
Expand All @@ -133,10 +112,18 @@ def update(

EH = jnp.stack(fields, axis=0)

new_phasor = EH * phasor * scale
# Vectorized phasor calculation for all frequencies
phase_angles = self._angular_frequencies[:, None] * time_passed # Shape: (num_freqs, 1)
phasors = jnp.exp(1j * phase_angles) # Shape: (num_freqs, 1)
new_phasors = EH[None, ...] * phasors[..., None] * self._scale # Broadcasting handles the multiplication

if self.reduce_volume:
# Average over all spatial dimensions
spatial_axes = tuple(range(2, new_phasors.ndim)) # Skip freq and component axes
new_phasors = new_phasors.mean(axis=spatial_axes) if spatial_axes else new_phasors

if self.inverse:
result = state["phasor"] - new_phasor[None, ...]
result = state["phasor"] - new_phasors[None, ...]
else:
result = state["phasor"] + new_phasor[None, ...]
result = state["phasor"] + new_phasors[None, ...]
return {"phasor": result.astype(self.dtype)}