|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import autograd as ag |
| 4 | +import autograd.numpy as anp |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | + |
| 8 | +import tidy3d as td |
| 9 | +import tidy3d.web as web |
| 10 | +from tidy3d.plugins.smatrix.analysis import terminal as terminal_analysis |
| 11 | +from tidy3d.plugins.smatrix.component_modelers.modal import ComponentModeler |
| 12 | +from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler |
| 13 | +from tidy3d.plugins.smatrix.data.data_array import TerminalPortDataArray |
| 14 | +from tidy3d.plugins.smatrix.ports.modal import Port as ModalPort |
| 15 | +from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort as RectLumpedPort |
| 16 | +from tidy3d.web.api.autograd import autograd as web_ag |
| 17 | + |
| 18 | + |
| 19 | +def _run_emulated_minimal(simulation: td.Simulation, path=None, **kwargs) -> td.SimulationData: |
| 20 | + """Very small offline emulator used by autograd tests. |
| 21 | +
|
| 22 | + - Supports ModeMonitor (amps + n_complex) |
| 23 | + - Supports FieldMonitor (Ex, Ey, Ez, Hx, Hy, Hz) |
| 24 | + - Supports PermittivityMonitor (eps_xx, eps_yy, eps_zz) |
| 25 | + """ |
| 26 | + |
| 27 | + rng = np.random.default_rng(42) |
| 28 | + |
| 29 | + def _coords_for_monitor(sim: td.Simulation, mnt: td.Monitor): |
| 30 | + grid = sim.discretize_monitor(mnt) |
| 31 | + bounds = grid.boundaries.dict() |
| 32 | + |
| 33 | + def centers(arr): |
| 34 | + arr = np.asarray(arr) |
| 35 | + if arr.size < 2: |
| 36 | + return arr |
| 37 | + return 0.5 * (arr[:-1] + arr[1:]) |
| 38 | + |
| 39 | + xyz = {} |
| 40 | + for ax, dim in enumerate("xyz"): |
| 41 | + if mnt.size[ax] == 0: |
| 42 | + xyz[dim] = [mnt.center[ax]] |
| 43 | + else: |
| 44 | + arr = np.asarray(bounds[dim]) |
| 45 | + if arr.size < 2: |
| 46 | + xyz[dim] = [mnt.center[ax]] |
| 47 | + else: |
| 48 | + xyz[dim] = centers(arr) |
| 49 | + |
| 50 | + # ensure at least two points along any nonzero-size axis to avoid empty-grid interpolation |
| 51 | + for ax, dim in enumerate("xyz"): |
| 52 | + if mnt.size[ax] != 0 and len(xyz[dim]) < 2: |
| 53 | + c = float(mnt.center[ax]) |
| 54 | + half = float(mnt.size[ax]) / 2.0 |
| 55 | + if half == 0: |
| 56 | + half = 1e-6 |
| 57 | + eps = max(half * 1e-3, 1e-6) |
| 58 | + xyz[dim] = [c - eps, c + eps] |
| 59 | + return xyz, grid |
| 60 | + |
| 61 | + data_items = [] |
| 62 | + |
| 63 | + for mnt in simulation.monitors: |
| 64 | + if isinstance(mnt, td.ModeMonitor): |
| 65 | + f = list(mnt.freqs) |
| 66 | + mode_index = np.arange(mnt.mode_spec.num_modes) |
| 67 | + directions = np.array(["+", "-"]) |
| 68 | + |
| 69 | + amps_vals = (1 + 0.1j) * rng.random((len(directions), len(f), len(mode_index))) |
| 70 | + n_vals = (1 + 0.05j) * rng.random((len(f), len(mode_index))) |
| 71 | + |
| 72 | + amps = td.ModeAmpsDataArray( |
| 73 | + amps_vals, |
| 74 | + coords={"direction": directions, "f": f, "mode_index": mode_index}, |
| 75 | + ) |
| 76 | + n_complex = td.ModeIndexDataArray(n_vals, coords={"f": f, "mode_index": mode_index}) |
| 77 | + |
| 78 | + data_items.append(td.ModeData(monitor=mnt, amps=amps, n_complex=n_complex)) |
| 79 | + |
| 80 | + elif isinstance(mnt, td.FieldMonitor): |
| 81 | + xyz, grid = _coords_for_monitor(simulation, mnt) |
| 82 | + f = list(mnt.freqs) |
| 83 | + shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f)) |
| 84 | + |
| 85 | + def cfield(shape=shape, xyz=xyz, f=f, rng=rng): |
| 86 | + vals = (1 + 0.2j) * rng.random(shape) |
| 87 | + return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f}) |
| 88 | + |
| 89 | + data_items.append( |
| 90 | + td.FieldData( |
| 91 | + monitor=mnt, |
| 92 | + grid_expanded=grid, |
| 93 | + Ex=cfield(), |
| 94 | + Ey=cfield(), |
| 95 | + Ez=cfield(), |
| 96 | + Hx=cfield(), |
| 97 | + Hy=cfield(), |
| 98 | + Hz=cfield(), |
| 99 | + symmetry=(0, 0, 0), |
| 100 | + symmetry_center=simulation.center, |
| 101 | + ) |
| 102 | + ) |
| 103 | + |
| 104 | + elif isinstance(mnt, td.PermittivityMonitor): |
| 105 | + xyz, grid = _coords_for_monitor(simulation, mnt) |
| 106 | + f = list(mnt.freqs) |
| 107 | + shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f)) |
| 108 | + |
| 109 | + def rfield(shape=shape, xyz=xyz, f=f, rng=rng): |
| 110 | + vals = rng.random(shape) |
| 111 | + return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f}) |
| 112 | + |
| 113 | + data_items.append( |
| 114 | + td.PermittivityData( |
| 115 | + monitor=mnt, |
| 116 | + grid_expanded=grid, |
| 117 | + eps_xx=rfield(), |
| 118 | + eps_yy=rfield(), |
| 119 | + eps_zz=rfield(), |
| 120 | + ) |
| 121 | + ) |
| 122 | + |
| 123 | + return td.SimulationData(simulation=simulation, data=tuple(data_items)) |
| 124 | + |
| 125 | + |
| 126 | +def _emulated_run_async_tidy3d(simulations, **kwargs): |
| 127 | + """Batch wrapper around the minimal emulator.""" |
| 128 | + sim_data_map = {} |
| 129 | + for task_name, sim in simulations.items(): |
| 130 | + sim_data_map[task_name] = _run_emulated_minimal(sim) |
| 131 | + |
| 132 | + class _BatchLike(dict): |
| 133 | + def __getitem__(self, key): |
| 134 | + return sim_data_map[key] |
| 135 | + |
| 136 | + return _BatchLike(sim_data_map), {} |
| 137 | + |
| 138 | + |
| 139 | +@pytest.fixture |
| 140 | +def patch_web_autograd_emulator(monkeypatch): |
| 141 | + """Patch web autograd internals to use the local minimal emulator.""" |
| 142 | + |
| 143 | + monkeypatch.setattr(web_ag, "_run_async_tidy3d", _emulated_run_async_tidy3d) |
| 144 | + yield |
| 145 | + |
| 146 | + |
| 147 | +def _build_base_sim(scale: float) -> td.Simulation: |
| 148 | + """Shared base Simulation for modal and terminal modelers.""" |
| 149 | + return td.Simulation( |
| 150 | + size=(4.0, 4.0, 4.0), |
| 151 | + run_time=1e-13, |
| 152 | + grid_spec=td.GridSpec.uniform(dl=0.2), |
| 153 | + boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True), |
| 154 | + structures=[ |
| 155 | + td.Structure( |
| 156 | + geometry=td.Box(size=(1.0 + scale, 1.0, 1.0), center=(0.0, 0.0, 0.0)), |
| 157 | + medium=td.Medium(permittivity=2.0 + 0.1 * scale), |
| 158 | + ) |
| 159 | + ], |
| 160 | + sources=[], |
| 161 | + monitors=[], |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +def build_modal_modeler(scale: float) -> ComponentModeler: |
| 166 | + sim = _build_base_sim(scale) |
| 167 | + |
| 168 | + # two modal ports on +/- z sides |
| 169 | + port_size = (2.0, 2.0, 0.0) |
| 170 | + p1 = ModalPort( |
| 171 | + center=(0.0, 0.0, -1.5), |
| 172 | + size=port_size, |
| 173 | + direction="+", |
| 174 | + mode_spec=td.ModeSpec(num_modes=1), |
| 175 | + name="p1", |
| 176 | + ) |
| 177 | + p2 = ModalPort( |
| 178 | + center=(0.0, 0.0, 1.5), |
| 179 | + size=port_size, |
| 180 | + direction="-", |
| 181 | + mode_spec=td.ModeSpec(num_modes=1), |
| 182 | + name="p2", |
| 183 | + ) |
| 184 | + |
| 185 | + freqs = [2.0e14] |
| 186 | + return ComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs) |
| 187 | + |
| 188 | + |
| 189 | +def build_terminal_modeler(scale: float) -> TerminalComponentModeler: |
| 190 | + sim = _build_base_sim(scale) |
| 191 | + |
| 192 | + # two lumped ports on +/- z sides; injection axis is z |
| 193 | + port_size = (1.0, 1.0, 0.0) |
| 194 | + p1 = RectLumpedPort( |
| 195 | + center=(0.0, 0.0, -1.5), |
| 196 | + size=port_size, |
| 197 | + voltage_axis=1, |
| 198 | + name="lp1", |
| 199 | + impedance=50.0, |
| 200 | + ) |
| 201 | + p2 = RectLumpedPort( |
| 202 | + center=(0.0, 0.0, 1.5), |
| 203 | + size=port_size, |
| 204 | + voltage_axis=1, |
| 205 | + name="lp2", |
| 206 | + impedance=50.0, |
| 207 | + ) |
| 208 | + |
| 209 | + freqs = [2.0e14] |
| 210 | + return TerminalComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs) |
| 211 | + |
| 212 | + |
| 213 | +def test_component_modeler_autograd_tracing(patch_web_autograd_emulator, tmp_path): |
| 214 | + td.config.logging_level = "ERROR" |
| 215 | + td.config.log_suppression = True |
| 216 | + |
| 217 | + def objective(scale: float) -> float: |
| 218 | + modeler = build_modal_modeler(scale) |
| 219 | + modeler_data = web.run( |
| 220 | + modeler, |
| 221 | + task_name="cm_autograd_test", |
| 222 | + path=tmp_path / "cm_autograd.hdf5", |
| 223 | + verbose=False, |
| 224 | + local_gradient=True, |
| 225 | + ) |
| 226 | + s = modeler_data.smatrix # ModalPortDataArray |
| 227 | + return anp.real(anp.sum(s.values)) |
| 228 | + |
| 229 | + # verify that gradients propagate without error |
| 230 | + g = ag.grad(objective)(1.0) |
| 231 | + assert np.isfinite(g) |
| 232 | + assert not np.isclose(g, 0.0) |
| 233 | + |
| 234 | + |
| 235 | +def test_terminal_component_modeler_autograd_tracing_stubbed( |
| 236 | + patch_web_autograd_emulator, monkeypatch, tmp_path |
| 237 | +): |
| 238 | + """Autograd plumbing test for TerminalComponentModeler using a minimal S-matrix stub. |
| 239 | +
|
| 240 | + This test verifies that web.run autograd integration (forward/adjoint batching and result |
| 241 | + composition) works for terminal modelers when terminal_construct_smatrix is replaced with a |
| 242 | + simple function of FieldData. The full terminal analysis relies on voltage/current integrals |
| 243 | + with interpolation and Yee-grid snapping, which are not autograd-compatible; hence the use of |
| 244 | + a stub to keep the test fast and robust. |
| 245 | + """ |
| 246 | + |
| 247 | + td.config.logging_level = "ERROR" |
| 248 | + td.config.log_suppression = True |
| 249 | + |
| 250 | + # Minimal stub: reduce first available FieldData per port over space (keep f), place on diagonal |
| 251 | + def _fake_terminal_construct_smatrix( |
| 252 | + modeler_data, assume_ideal_excitation=False, s_param_def="pseudo" |
| 253 | + ): |
| 254 | + ports = list(modeler_data.modeler.network_dict.keys()) |
| 255 | + freqs = list(modeler_data.modeler.freqs) |
| 256 | + f_len = len(freqs) |
| 257 | + n = len(ports) |
| 258 | + |
| 259 | + diag_vals = [] |
| 260 | + for p in ports: |
| 261 | + sim_data = modeler_data.data[p] |
| 262 | + vals_f = None |
| 263 | + for d in sim_data.data: |
| 264 | + if isinstance(d, td.FieldData): |
| 265 | + arr = next(iter(d.field_components.values())) |
| 266 | + val_comp = arr.sum(dim=[c for c in arr.dims if c != "f"]).astype(complex) |
| 267 | + if "f" not in val_comp.dims: |
| 268 | + val_comp = td.FreqDataArray( |
| 269 | + np.ones((f_len,), dtype=complex), coords={"f": freqs} |
| 270 | + ) |
| 271 | + vals_f = val_comp |
| 272 | + break |
| 273 | + if vals_f is None: |
| 274 | + vals_f = td.FreqDataArray(np.ones((f_len,), dtype=complex), coords={"f": freqs}) |
| 275 | + diag_vals.append(vals_f) |
| 276 | + |
| 277 | + data = anp.zeros((f_len, n, n), dtype=complex) |
| 278 | + for i, s in enumerate(diag_vals): |
| 279 | + Ei = np.zeros((n, n)) |
| 280 | + Ei[i, i] = 1.0 |
| 281 | + data = data + anp.einsum("f,ij->fij", s.values, Ei) |
| 282 | + |
| 283 | + return TerminalPortDataArray(data, coords={"f": freqs, "port_out": ports, "port_in": ports}) |
| 284 | + |
| 285 | + monkeypatch.setattr( |
| 286 | + terminal_analysis, "terminal_construct_smatrix", _fake_terminal_construct_smatrix |
| 287 | + ) |
| 288 | + |
| 289 | + def objective(scale: float) -> float: |
| 290 | + modeler = build_terminal_modeler(scale) |
| 291 | + _ = tmp_path / "cm_terminal_autograd.hdf5" |
| 292 | + modeler_data = web.run( |
| 293 | + modeler, |
| 294 | + task_name="cm_terminal_autograd_test", |
| 295 | + path=str(_), |
| 296 | + verbose=False, |
| 297 | + local_gradient=True, |
| 298 | + ) |
| 299 | + s_vals = modeler_data.smatrix().data.values |
| 300 | + return anp.real(anp.sum(s_vals)) |
| 301 | + |
| 302 | + g = ag.grad(objective)(1.0) |
| 303 | + assert np.isfinite(g) |
| 304 | + assert not np.isclose(g, 0.0) |
0 commit comments