11from __future__ import annotations
22
33from collections import deque
4- from typing import Callable , Optional , Union
4+ from typing import Callable , Literal , Optional , Union
55
66import autograd .numpy as np
77import pydantic .v1 as pd
@@ -110,6 +110,7 @@ def initialize_params_from_simulation(
110110 params0 : np .ndarray ,
111111 * ,
112112 freq : Optional [float ] = None ,
113+ outside_handling : Literal ["extrapolate" , "mask" , "nan" ] = "mask" ,
113114 maxiter : int = 100 ,
114115 bounds : tuple [Optional [float ], Optional [float ]] = (0.0 , 1.0 ),
115116 rel_improve_tol : float = 1e-3 ,
@@ -133,6 +134,10 @@ def initialize_params_from_simulation(
133134 using coordinate-aware interpolation (no per-iteration interpolation).
134135 - Early stopping uses a single knob ``rel_improve_tol``; optimization stops once the
135136 relative improvement over a small fixed window falls below this value.
137+ - Points outside the base‑epsilon coverage can be handled via ``outside_handling``:
138+ ``'extrapolate'`` (use nearest extrapolation, default in earlier versions),
139+ ``'mask'`` (ignore outside points using a coverage mask; default), or
140+ ``'nan'`` (non‑extended grid; treat outside as NaN and ignore in the loss).
136141
137142 Parameters
138143 ----------
@@ -148,6 +153,12 @@ def initialize_params_from_simulation(
148153 Maximum number of L‑BFGS‑B iterations.
149154 freq : float, optional
150155 Frequency at which permittivity is evaluated. If ``None``, uses infinite frequency.
156+ outside_handling : {"extrapolate", "mask", "nan"} = "mask"
157+ Strategy for points where design coordinates fall outside the sampled base epsilon:
158+ - "extrapolate": include them using nearest extrapolation.
159+ - "mask": include only points within the coverage bounds of ``sim.epsilon`` on the
160+ extended subgrid.
161+ - "nan": sample on a non‑extended subgrid and ignore points where interpolation returns NaN.
151162 bounds : tuple[float | None, float | None] = (0.0, 1.0)
152163 Element‑wise parameter bounds, e.g. ``(0.0, 1.0)`` or ``(-1.0, 1.0)``. Use ``None`` to
153164 indicate an unbounded side.
@@ -193,7 +204,13 @@ def initialize_params_from_simulation(
193204 (3, 3)
194205 """
195206 structure_init = param_to_structure (params0 )
196- subgrid = sim .discretize (structure_init .geometry , extend = True )
207+
208+ if outside_handling not in ("extrapolate" , "mask" , "nan" ):
209+ raise ValueError ("'outside_handling' must be one of {'extrapolate', 'mask', 'nan'}." )
210+
211+ # build base epsilon and interpolate onto design coords
212+ extend_flag = outside_handling != "nan"
213+ subgrid = sim .discretize (structure_init .geometry , extend = extend_flag )
197214 eps_base_da = sim .epsilon_on_grid (grid = subgrid , coord_key = "centers" , freq = freq )
198215
199216 design_eps_da = structure_init .medium .permittivity
@@ -202,18 +219,53 @@ def initialize_params_from_simulation(
202219 y = np .array (design_eps_da .coords ["y" ]),
203220 z = np .array (design_eps_da .coords ["z" ]),
204221 )
205- eps_base_interp = design_coords .spatial_interp (
206- array = eps_base_da , interp_method = "linear" , fill_value = "extrapolate"
207- ).data
208222
209- denom = np .sqrt (np .sum (eps_base_interp .real ** 2 + eps_base_interp .imag ** 2 ))
223+ if outside_handling == "nan" :
224+ eps_base_interp = design_coords .spatial_interp (
225+ array = eps_base_da , interp_method = "linear" , fill_value = np .nan
226+ ).data
227+ mask = np .isfinite (eps_base_interp )
228+ else :
229+ eps_base_interp = design_coords .spatial_interp (
230+ array = eps_base_da , interp_method = "linear" , fill_value = "extrapolate"
231+ ).data
232+ if outside_handling == "mask" :
233+ # build mask from coverage bounds of base epsilon coordinates
234+ xb , yb , zb = [np .array (eps_base_da .coords [d ]) for d in ("x" , "y" , "z" )]
235+ xd , yd , zd = [np .array (design_eps_da .coords [d ]) for d in ("x" , "y" , "z" )]
236+ mask_x = (xd >= np .min (xb )) & (xd <= np .max (xb ))
237+ mask_y = (yd >= np .min (yb )) & (yd <= np .max (yb ))
238+ mask_z = (zd >= np .min (zb )) & (zd <= np .max (zb ))
239+ mask = (mask_x [:, None , None ]) & (mask_y [None , :, None ]) & (mask_z [None , None , :])
240+ else :
241+ mask = None
242+
243+ if mask is not None :
244+ covered = int (np .sum (mask ))
245+ total = int (np .prod (mask .shape ))
246+ frac = covered / max (total , 1 )
247+ if frac < 0.9 :
248+ td .log .warning (
249+ f"Only { frac :.1%} of design points are covered by base epsilon sampling. "
250+ "Consider adding a 'MeshOverrideStructure' or adjusting design coordinates."
251+ )
252+
253+ if mask is None :
254+ denom = np .sqrt (np .sum (eps_base_interp .real ** 2 + eps_base_interp .imag ** 2 ))
255+ else :
256+ denom = np .sqrt (
257+ np .sum ((eps_base_interp .real [mask ]) ** 2 + (eps_base_interp .imag [mask ]) ** 2 )
258+ )
210259 denom = 1.0 if denom == 0 else denom
211260
212261 def loss_fn (params_vec : np .ndarray ) -> float :
213262 params = params_vec .reshape (params0 .shape )
214263 structure = param_to_structure (params )
215264 eps_design = structure .medium .permittivity .data
216- res = eps_base_interp - eps_design
265+ if mask is None :
266+ res = eps_base_interp - eps_design
267+ return 0.5 * np .sum (res .real ** 2 + res .imag ** 2 ) / denom
268+ res = (eps_base_interp - eps_design )[mask ]
217269 return 0.5 * np .sum (res .real ** 2 + res .imag ** 2 ) / denom
218270
219271 val_and_grad = value_and_grad (loss_fn )
0 commit comments