Skip to content

Commit

Permalink
Allow transformed groups to be flattened (#2050)
Browse files Browse the repository at this point in the history
Related to:

- flexcompute/tidy3d-core#751

- flexcompute/tidy3d-core#750

Signed-off-by: Lucas Heitzmann Gabrielli <[email protected]>
  • Loading branch information
lucas-flexcompute authored Nov 1, 2024
1 parent 164b4b8 commit 76ab553
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
- Function `components.geometry.utils.flatten_groups` now also flattens transformed groups when requested.

### Fixed
- Regression in local field projection leading to incorrect results for `far_field_approx=True`.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_components/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,33 @@ def test_flattening():
for g in flat
)

t0 = np.array([[2, 0, 0, 0], [3, 2, 0, 0], [1, 0, 2, 0], [0, 0, 0, 1.0]])
g0 = td.Sphere(radius=1)
t1 = np.array([[2, 0, 5, 0], [0, 1, 0, 0], [-1, 0, 1, 0], [0, 0, 0, 1.0]])
g1 = td.Box(size=(1, 2, 3))
flat = list(
flatten_groups(
td.Transformed(
transform=t0,
geometry=td.ClipOperation(
operation="union",
geometry_a=g0,
geometry_b=td.Transformed(transform=t1, geometry=g1),
),
),
flatten_transformed=True,
)
)
assert len(flat) == 2

assert isinstance(flat[0], td.Transformed)
assert flat[0].geometry == g0
assert np.allclose(flat[0].transform, t0)

assert isinstance(flat[1], td.Transformed)
assert flat[1].geometry == g1
assert np.allclose(flat[1].transform, t0 @ t1)


def test_geometry_traversal():
geometries = list(traverse_geometries(td.Box(size=(1, 1, 1))))
Expand Down
33 changes: 29 additions & 4 deletions tidy3d/components/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from math import isclose
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np

Expand All @@ -24,17 +24,25 @@
]


def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = False) -> GeometryType:
def flatten_groups(
*geometries: GeometryType,
flatten_nonunion_type: bool = False,
flatten_transformed: bool = False,
transform: Optional[MatrixReal4x4] = None,
) -> GeometryType:
"""Iterates over all geometries, flattening groups and unions.
Parameters
----------
*geometries : GeometryType
Geometries to flatten.
flatten_nonunion_type : bool = False
If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten
all clip operations.
flatten_transformed : bool = False
If ``True``, ``Transformed`` groups are flattened into individual transformed geometries.
transform : Optional[MatrixReal4x4]
Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``.
Yields
------
Expand All @@ -44,7 +52,10 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
for geometry in geometries:
if isinstance(geometry, base.GeometryGroup):
yield from flatten_groups(
*geometry.geometries, flatten_nonunion_type=flatten_nonunion_type
*geometry.geometries,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=transform,
)
elif isinstance(geometry, base.ClipOperation) and (
flatten_nonunion_type or geometry.operation == "union"
Expand All @@ -53,7 +64,21 @@ def flatten_groups(*geometries: GeometryType, flatten_nonunion_type: bool = Fals
geometry.geometry_a,
geometry.geometry_b,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=transform,
)
elif flatten_transformed and isinstance(geometry, base.Transformed):
new_transform = geometry.transform
if transform is not None:
new_transform = np.matmul(transform, new_transform)
yield from flatten_groups(
geometry.geometry,
flatten_nonunion_type=flatten_nonunion_type,
flatten_transformed=flatten_transformed,
transform=new_transform,
)
elif flatten_transformed and transform is not None:
yield base.Transformed(geometry=geometry, transform=transform)
else:
yield geometry

Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _validate_num_geometries(cls, val):
return val

for i, structure in enumerate(val):
for geometry in flatten_groups(structure.geometry):
for geometry in flatten_groups(structure.geometry, flatten_transformed=True):
count = sum(
1
for g in traverse_geometries(geometry)
Expand Down

0 comments on commit 76ab553

Please sign in to comment.