Skip to content

Commit 2f2657e

Browse files
committed
warnings and more edge case handling
1 parent 2938271 commit 2f2657e

File tree

1 file changed

+59
-7
lines changed

1 file changed

+59
-7
lines changed

tidy3d/plugins/autograd/invdes/parametrizations.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections import deque
4-
from typing import Callable, Optional, Union
4+
from typing import Callable, Literal, Optional, Union
55

66
import autograd.numpy as np
77
import pydantic.v1 as pd
@@ -110,6 +110,7 @@ def initialize_params_from_simulation(
110110
params0: np.ndarray,
111111
*,
112112
freq: Optional[float] = None,
113+
outside_handling: Literal["extrapolate", "mask", "nan"] = "mask",
113114
maxiter: int = 100,
114115
bounds: tuple[Optional[float], Optional[float]] = (0.0, 1.0),
115116
rel_improve_tol: float = 1e-3,
@@ -133,6 +134,10 @@ def initialize_params_from_simulation(
133134
using coordinate-aware interpolation (no per-iteration interpolation).
134135
- Early stopping uses a single knob ``rel_improve_tol``; optimization stops once the
135136
relative improvement over a small fixed window falls below this value.
137+
- Points outside the base‑epsilon coverage can be handled via ``outside_handling``:
138+
``'extrapolate'`` (use nearest extrapolation, default in earlier versions),
139+
``'mask'`` (ignore outside points using a coverage mask; default), or
140+
``'nan'`` (non‑extended grid; treat outside as NaN and ignore in the loss).
136141
137142
Parameters
138143
----------
@@ -148,6 +153,12 @@ def initialize_params_from_simulation(
148153
Maximum number of L‑BFGS‑B iterations.
149154
freq : float, optional
150155
Frequency at which permittivity is evaluated. If ``None``, uses infinite frequency.
156+
outside_handling : {"extrapolate", "mask", "nan"} = "mask"
157+
Strategy for points where design coordinates fall outside the sampled base epsilon:
158+
- "extrapolate": include them using nearest extrapolation.
159+
- "mask": include only points within the coverage bounds of ``sim.epsilon`` on the
160+
extended subgrid.
161+
- "nan": sample on a non‑extended subgrid and ignore points where interpolation returns NaN.
151162
bounds : tuple[float | None, float | None] = (0.0, 1.0)
152163
Element‑wise parameter bounds, e.g. ``(0.0, 1.0)`` or ``(-1.0, 1.0)``. Use ``None`` to
153164
indicate an unbounded side.
@@ -193,7 +204,13 @@ def initialize_params_from_simulation(
193204
(3, 3)
194205
"""
195206
structure_init = param_to_structure(params0)
196-
subgrid = sim.discretize(structure_init.geometry, extend=True)
207+
208+
if outside_handling not in ("extrapolate", "mask", "nan"):
209+
raise ValueError("'outside_handling' must be one of {'extrapolate', 'mask', 'nan'}.")
210+
211+
# build base epsilon and interpolate onto design coords
212+
extend_flag = outside_handling != "nan"
213+
subgrid = sim.discretize(structure_init.geometry, extend=extend_flag)
197214
eps_base_da = sim.epsilon_on_grid(grid=subgrid, coord_key="centers", freq=freq)
198215

199216
design_eps_da = structure_init.medium.permittivity
@@ -202,18 +219,53 @@ def initialize_params_from_simulation(
202219
y=np.array(design_eps_da.coords["y"]),
203220
z=np.array(design_eps_da.coords["z"]),
204221
)
205-
eps_base_interp = design_coords.spatial_interp(
206-
array=eps_base_da, interp_method="linear", fill_value="extrapolate"
207-
).data
208222

209-
denom = np.sqrt(np.sum(eps_base_interp.real**2 + eps_base_interp.imag**2))
223+
if outside_handling == "nan":
224+
eps_base_interp = design_coords.spatial_interp(
225+
array=eps_base_da, interp_method="linear", fill_value=np.nan
226+
).data
227+
mask = np.isfinite(eps_base_interp)
228+
else:
229+
eps_base_interp = design_coords.spatial_interp(
230+
array=eps_base_da, interp_method="linear", fill_value="extrapolate"
231+
).data
232+
if outside_handling == "mask":
233+
# build mask from coverage bounds of base epsilon coordinates
234+
xb, yb, zb = [np.array(eps_base_da.coords[d]) for d in ("x", "y", "z")]
235+
xd, yd, zd = [np.array(design_eps_da.coords[d]) for d in ("x", "y", "z")]
236+
mask_x = (xd >= np.min(xb)) & (xd <= np.max(xb))
237+
mask_y = (yd >= np.min(yb)) & (yd <= np.max(yb))
238+
mask_z = (zd >= np.min(zb)) & (zd <= np.max(zb))
239+
mask = (mask_x[:, None, None]) & (mask_y[None, :, None]) & (mask_z[None, None, :])
240+
else:
241+
mask = None
242+
243+
if mask is not None:
244+
covered = int(np.sum(mask))
245+
total = int(np.prod(mask.shape))
246+
frac = covered / max(total, 1)
247+
if frac < 0.9:
248+
td.log.warning(
249+
f"Only {frac:.1%} of design points are covered by base epsilon sampling. "
250+
"Consider adding a 'MeshOverrideStructure' or adjusting design coordinates."
251+
)
252+
253+
if mask is None:
254+
denom = np.sqrt(np.sum(eps_base_interp.real**2 + eps_base_interp.imag**2))
255+
else:
256+
denom = np.sqrt(
257+
np.sum((eps_base_interp.real[mask]) ** 2 + (eps_base_interp.imag[mask]) ** 2)
258+
)
210259
denom = 1.0 if denom == 0 else denom
211260

212261
def loss_fn(params_vec: np.ndarray) -> float:
213262
params = params_vec.reshape(params0.shape)
214263
structure = param_to_structure(params)
215264
eps_design = structure.medium.permittivity.data
216-
res = eps_base_interp - eps_design
265+
if mask is None:
266+
res = eps_base_interp - eps_design
267+
return 0.5 * np.sum(res.real**2 + res.imag**2) / denom
268+
res = (eps_base_interp - eps_design)[mask]
217269
return 0.5 * np.sum(res.real**2 + res.imag**2) / denom
218270

219271
val_and_grad = value_and_grad(loss_fn)

0 commit comments

Comments
 (0)