Skip to content

Commit 2938271

Browse files
committed
fix interpolation
1 parent fe962a6 commit 2938271

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

tidy3d/plugins/autograd/invdes/parametrizations.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import pydantic.v1 as pd
88
from autograd import value_and_grad
99
from numpy.typing import NDArray
10-
from scipy.ndimage import zoom
1110
from scipy.optimize import minimize
1211

1312
import tidy3d as td
1413
from tidy3d.components.base import Tidy3dBaseModel
14+
from tidy3d.components.grid.grid import Coords
1515
from tidy3d.plugins.autograd.constants import BETA_DEFAULT, ETA_DEFAULT
1616
from tidy3d.plugins.autograd.types import KernelType, PaddingType
1717

@@ -109,6 +109,7 @@ def initialize_params_from_simulation(
109109
param_to_structure: Callable[[np.ndarray], td.Structure],
110110
params0: np.ndarray,
111111
*,
112+
freq: Optional[float] = None,
112113
maxiter: int = 100,
113114
bounds: tuple[Optional[float], Optional[float]] = (0.0, 1.0),
114115
rel_improve_tol: float = 1e-3,
@@ -127,6 +128,9 @@ def initialize_params_from_simulation(
127128
- The callable ``param_to_structure`` controls the parameterization and must return
128129
a :class:`.Structure` (typically with a :class:`.CustomMedium`). Its permittivity
129130
dataset (and coordinates) defines the grid used for comparison.
131+
- The base simulation permittivity is sampled once on an extended subgrid covering
132+
the design geometry and then interpolated onto the design region coordinates
133+
using coordinate-aware interpolation (no per-iteration interpolation).
130134
- Early stopping uses a single knob ``rel_improve_tol``; optimization stops once the
131135
relative improvement over a small fixed window falls below this value.
132136
@@ -188,12 +192,19 @@ def initialize_params_from_simulation(
188192
>>> params.shape
189193
(3, 3)
190194
"""
191-
192195
structure_init = param_to_structure(params0)
193-
eps_base = sim.epsilon(box=structure_init.geometry).data
194-
eps_param = structure_init.medium.permittivity.data
195-
factors = np.array(eps_param.shape) / np.array(eps_base.shape)
196-
eps_base_interp = zoom(eps_base, factors, order=1)
196+
subgrid = sim.discretize(structure_init.geometry, extend=True)
197+
eps_base_da = sim.epsilon_on_grid(grid=subgrid, coord_key="centers", freq=freq)
198+
199+
design_eps_da = structure_init.medium.permittivity
200+
design_coords = Coords(
201+
x=np.array(design_eps_da.coords["x"]),
202+
y=np.array(design_eps_da.coords["y"]),
203+
z=np.array(design_eps_da.coords["z"]),
204+
)
205+
eps_base_interp = design_coords.spatial_interp(
206+
array=eps_base_da, interp_method="linear", fill_value="extrapolate"
207+
).data
197208

198209
denom = np.sqrt(np.sum(eps_base_interp.real**2 + eps_base_interp.imag**2))
199210
denom = 1.0 if denom == 0 else denom
@@ -214,9 +225,8 @@ def loss_fn(params_vec: np.ndarray) -> float:
214225
"best_x": params0.ravel().copy(),
215226
}
216227

217-
def callback(intermediate_result):
218-
xk = intermediate_result.x
219-
val = intermediate_result.fun
228+
def callback(xk: np.ndarray):
229+
val = loss_fn(xk)
220230
if val < state["best_val"]:
221231
state["best_val"] = val
222232
state["best_x"] = xk.copy()

0 commit comments

Comments
 (0)