Skip to content

Commit 198d60c

Browse files
committed
wip: autograd support for web.run(component_modeler)
1 parent e404a1e commit 198d60c

File tree

2 files changed

+450
-26
lines changed

2 files changed

+450
-26
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
from __future__ import annotations
2+
3+
import autograd as ag
4+
import autograd.numpy as anp
5+
import numpy as np
6+
import pytest
7+
8+
import tidy3d as td
9+
import tidy3d.web as web
10+
from tidy3d.plugins.smatrix.analysis import terminal as terminal_analysis
11+
from tidy3d.plugins.smatrix.component_modelers.modal import ComponentModeler
12+
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
13+
from tidy3d.plugins.smatrix.data.data_array import TerminalPortDataArray
14+
from tidy3d.plugins.smatrix.ports.modal import Port as ModalPort
15+
from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort as RectLumpedPort
16+
from tidy3d.web.api.autograd import autograd as web_ag
17+
18+
19+
def _run_emulated_minimal(simulation: td.Simulation, path=None, **kwargs) -> td.SimulationData:
20+
"""Very small offline emulator used by autograd tests.
21+
22+
- Supports ModeMonitor (amps + n_complex)
23+
- Supports FieldMonitor (Ex, Ey, Ez, Hx, Hy, Hz)
24+
- Supports PermittivityMonitor (eps_xx, eps_yy, eps_zz)
25+
"""
26+
27+
rng = np.random.default_rng(42)
28+
29+
def _coords_for_monitor(sim: td.Simulation, mnt: td.Monitor):
30+
grid = sim.discretize_monitor(mnt)
31+
bounds = grid.boundaries.dict()
32+
33+
def centers(arr):
34+
arr = np.asarray(arr)
35+
if arr.size < 2:
36+
return arr
37+
return 0.5 * (arr[:-1] + arr[1:])
38+
39+
xyz = {}
40+
for ax, dim in enumerate("xyz"):
41+
if mnt.size[ax] == 0:
42+
xyz[dim] = [mnt.center[ax]]
43+
else:
44+
arr = np.asarray(bounds[dim])
45+
if arr.size < 2:
46+
xyz[dim] = [mnt.center[ax]]
47+
else:
48+
xyz[dim] = centers(arr)
49+
50+
# ensure at least two points along any nonzero-size axis to avoid empty-grid interpolation
51+
for ax, dim in enumerate("xyz"):
52+
if mnt.size[ax] != 0 and len(xyz[dim]) < 2:
53+
c = float(mnt.center[ax])
54+
half = float(mnt.size[ax]) / 2.0
55+
if half == 0:
56+
half = 1e-6
57+
eps = max(half * 1e-3, 1e-6)
58+
xyz[dim] = [c - eps, c + eps]
59+
return xyz, grid
60+
61+
data_items = []
62+
63+
for mnt in simulation.monitors:
64+
if isinstance(mnt, td.ModeMonitor):
65+
f = list(mnt.freqs)
66+
mode_index = np.arange(mnt.mode_spec.num_modes)
67+
directions = np.array(["+", "-"])
68+
69+
amps_vals = (1 + 0.1j) * rng.random((len(directions), len(f), len(mode_index)))
70+
n_vals = (1 + 0.05j) * rng.random((len(f), len(mode_index)))
71+
72+
amps = td.ModeAmpsDataArray(
73+
amps_vals,
74+
coords={"direction": directions, "f": f, "mode_index": mode_index},
75+
)
76+
n_complex = td.ModeIndexDataArray(n_vals, coords={"f": f, "mode_index": mode_index})
77+
78+
data_items.append(td.ModeData(monitor=mnt, amps=amps, n_complex=n_complex))
79+
80+
elif isinstance(mnt, td.FieldMonitor):
81+
xyz, grid = _coords_for_monitor(simulation, mnt)
82+
f = list(mnt.freqs)
83+
shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f))
84+
85+
def cfield(shape=shape, xyz=xyz, f=f, rng=rng):
86+
vals = (1 + 0.2j) * rng.random(shape)
87+
return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f})
88+
89+
data_items.append(
90+
td.FieldData(
91+
monitor=mnt,
92+
grid_expanded=grid,
93+
Ex=cfield(),
94+
Ey=cfield(),
95+
Ez=cfield(),
96+
Hx=cfield(),
97+
Hy=cfield(),
98+
Hz=cfield(),
99+
symmetry=(0, 0, 0),
100+
symmetry_center=simulation.center,
101+
)
102+
)
103+
104+
elif isinstance(mnt, td.PermittivityMonitor):
105+
xyz, grid = _coords_for_monitor(simulation, mnt)
106+
f = list(mnt.freqs)
107+
shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f))
108+
109+
def rfield(shape=shape, xyz=xyz, f=f, rng=rng):
110+
vals = rng.random(shape)
111+
return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f})
112+
113+
data_items.append(
114+
td.PermittivityData(
115+
monitor=mnt,
116+
grid_expanded=grid,
117+
eps_xx=rfield(),
118+
eps_yy=rfield(),
119+
eps_zz=rfield(),
120+
)
121+
)
122+
123+
return td.SimulationData(simulation=simulation, data=tuple(data_items))
124+
125+
126+
def _emulated_run_async_tidy3d(simulations, **kwargs):
127+
"""Batch wrapper around the minimal emulator."""
128+
sim_data_map = {}
129+
for task_name, sim in simulations.items():
130+
sim_data_map[task_name] = _run_emulated_minimal(sim)
131+
132+
class _BatchLike(dict):
133+
def __getitem__(self, key):
134+
return sim_data_map[key]
135+
136+
return _BatchLike(sim_data_map), {}
137+
138+
139+
@pytest.fixture
140+
def patch_web_autograd_emulator(monkeypatch):
141+
"""Patch web autograd internals to use the local minimal emulator."""
142+
143+
monkeypatch.setattr(web_ag, "_run_async_tidy3d", _emulated_run_async_tidy3d)
144+
yield
145+
146+
147+
def _build_base_sim(scale: float) -> td.Simulation:
148+
"""Shared base Simulation for modal and terminal modelers."""
149+
return td.Simulation(
150+
size=(4.0, 4.0, 4.0),
151+
run_time=1e-13,
152+
grid_spec=td.GridSpec.uniform(dl=0.2),
153+
boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True),
154+
structures=[
155+
td.Structure(
156+
geometry=td.Box(size=(1.0 + scale, 1.0, 1.0), center=(0.0, 0.0, 0.0)),
157+
medium=td.Medium(permittivity=2.0 + 0.1 * scale),
158+
)
159+
],
160+
sources=[],
161+
monitors=[],
162+
)
163+
164+
165+
def build_modal_modeler(scale: float) -> ComponentModeler:
166+
sim = _build_base_sim(scale)
167+
168+
# two modal ports on +/- z sides
169+
port_size = (2.0, 2.0, 0.0)
170+
p1 = ModalPort(
171+
center=(0.0, 0.0, -1.5),
172+
size=port_size,
173+
direction="+",
174+
mode_spec=td.ModeSpec(num_modes=1),
175+
name="p1",
176+
)
177+
p2 = ModalPort(
178+
center=(0.0, 0.0, 1.5),
179+
size=port_size,
180+
direction="-",
181+
mode_spec=td.ModeSpec(num_modes=1),
182+
name="p2",
183+
)
184+
185+
freqs = [2.0e14]
186+
return ComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs)
187+
188+
189+
def build_terminal_modeler(scale: float) -> TerminalComponentModeler:
190+
sim = _build_base_sim(scale)
191+
192+
# two lumped ports on +/- z sides; injection axis is z
193+
port_size = (1.0, 1.0, 0.0)
194+
p1 = RectLumpedPort(
195+
center=(0.0, 0.0, -1.5),
196+
size=port_size,
197+
voltage_axis=1,
198+
name="lp1",
199+
impedance=50.0,
200+
)
201+
p2 = RectLumpedPort(
202+
center=(0.0, 0.0, 1.5),
203+
size=port_size,
204+
voltage_axis=1,
205+
name="lp2",
206+
impedance=50.0,
207+
)
208+
209+
freqs = [2.0e14]
210+
return TerminalComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs)
211+
212+
213+
def test_component_modeler_autograd_tracing(patch_web_autograd_emulator, tmp_path):
214+
td.config.logging_level = "ERROR"
215+
td.config.log_suppression = True
216+
217+
def objective(scale: float) -> float:
218+
modeler = build_modal_modeler(scale)
219+
modeler_data = web.run(
220+
modeler,
221+
task_name="cm_autograd_test",
222+
path=tmp_path / "cm_autograd.hdf5",
223+
verbose=False,
224+
local_gradient=True,
225+
)
226+
s = modeler_data.smatrix # ModalPortDataArray
227+
return anp.real(anp.sum(s.values))
228+
229+
# verify that gradients propagate without error
230+
g = ag.grad(objective)(1.0)
231+
assert np.isfinite(g)
232+
assert not np.isclose(g, 0.0)
233+
234+
235+
def test_terminal_component_modeler_autograd_tracing_stubbed(
236+
patch_web_autograd_emulator, monkeypatch, tmp_path
237+
):
238+
"""Autograd plumbing test for TerminalComponentModeler using a minimal S-matrix stub.
239+
240+
This test verifies that web.run autograd integration (forward/adjoint batching and result
241+
composition) works for terminal modelers when terminal_construct_smatrix is replaced with a
242+
simple function of FieldData. The full terminal analysis relies on voltage/current integrals
243+
with interpolation and Yee-grid snapping, which are not autograd-compatible; hence the use of
244+
a stub to keep the test fast and robust.
245+
"""
246+
247+
td.config.logging_level = "ERROR"
248+
td.config.log_suppression = True
249+
250+
# Minimal stub: reduce first available FieldData per port over space (keep f), place on diagonal
251+
def _fake_terminal_construct_smatrix(
252+
modeler_data, assume_ideal_excitation=False, s_param_def="pseudo"
253+
):
254+
ports = list(modeler_data.modeler.network_dict.keys())
255+
freqs = list(modeler_data.modeler.freqs)
256+
f_len = len(freqs)
257+
n = len(ports)
258+
259+
diag_vals = []
260+
for p in ports:
261+
sim_data = modeler_data.data[p]
262+
vals_f = None
263+
for d in sim_data.data:
264+
if isinstance(d, td.FieldData):
265+
arr = next(iter(d.field_components.values()))
266+
val_comp = arr.sum(dim=[c for c in arr.dims if c != "f"]).astype(complex)
267+
if "f" not in val_comp.dims:
268+
val_comp = td.FreqDataArray(
269+
np.ones((f_len,), dtype=complex), coords={"f": freqs}
270+
)
271+
vals_f = val_comp
272+
break
273+
if vals_f is None:
274+
vals_f = td.FreqDataArray(np.ones((f_len,), dtype=complex), coords={"f": freqs})
275+
diag_vals.append(vals_f)
276+
277+
data = anp.zeros((f_len, n, n), dtype=complex)
278+
for i, s in enumerate(diag_vals):
279+
Ei = np.zeros((n, n))
280+
Ei[i, i] = 1.0
281+
data = data + anp.einsum("f,ij->fij", s.values, Ei)
282+
283+
return TerminalPortDataArray(data, coords={"f": freqs, "port_out": ports, "port_in": ports})
284+
285+
monkeypatch.setattr(
286+
terminal_analysis, "terminal_construct_smatrix", _fake_terminal_construct_smatrix
287+
)
288+
289+
def objective(scale: float) -> float:
290+
modeler = build_terminal_modeler(scale)
291+
_ = tmp_path / "cm_terminal_autograd.hdf5"
292+
modeler_data = web.run(
293+
modeler,
294+
task_name="cm_terminal_autograd_test",
295+
path=str(_),
296+
verbose=False,
297+
local_gradient=True,
298+
)
299+
s_vals = modeler_data.smatrix().data.values
300+
return anp.real(anp.sum(s_vals))
301+
302+
g = ag.grad(objective)(1.0)
303+
assert np.isfinite(g)
304+
assert not np.isclose(g, 0.0)

0 commit comments

Comments
 (0)