Skip to content

Commit bd88ce0

Browse files
Propagate latest changes to autograd CM (#2755)
* Update to new data name * Remove polluting scope * no issue mixing import scope on run.py * Update tidy3d/plugins/smatrix/component_modelers/base.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Readd PayType to run.py * Correct fix --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent ea6840c commit bd88ce0

File tree

5 files changed

+68
-89
lines changed

5 files changed

+68
-89
lines changed

docs/faq

tests/test_plugins/smatrix/test_component_modeler_autograd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tidy3d as td
99
import tidy3d.web as web
1010
from tidy3d.plugins.smatrix.analysis import terminal as terminal_analysis
11-
from tidy3d.plugins.smatrix.component_modelers.modal import ComponentModeler
11+
from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModelerData
1212
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
1313
from tidy3d.plugins.smatrix.data.data_array import TerminalPortDataArray
1414
from tidy3d.plugins.smatrix.ports.modal import Port as ModalPort
@@ -162,7 +162,7 @@ def _build_base_sim(scale: float) -> td.Simulation:
162162
)
163163

164164

165-
def build_modal_modeler(scale: float) -> ComponentModeler:
165+
def build_modal_modeler(scale: float) -> ModalComponentModelerData:
166166
sim = _build_base_sim(scale)
167167

168168
# two modal ports on +/- z sides
@@ -183,7 +183,7 @@ def build_modal_modeler(scale: float) -> ComponentModeler:
183183
)
184184

185185
freqs = [2.0e14]
186-
return ComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs)
186+
return ModalComponentModelerData(simulation=sim, ports=(p1, p2), freqs=freqs)
187187

188188

189189
def build_terminal_modeler(scale: float) -> TerminalComponentModeler:

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 4 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
from __future__ import annotations
44

5-
import os
65
from abc import ABC, abstractmethod
7-
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, get_args
6+
from typing import Generic, Optional, TypeVar, Union, get_args
87

98
import numpy as np
109
import pydantic.v1 as pd
@@ -27,6 +26,8 @@
2726
from tidy3d.plugins.smatrix.ports.modal import Port
2827
from tidy3d.plugins.smatrix.ports.types import TerminalPortType
2928
from tidy3d.plugins.smatrix.ports.wave import WavePort
29+
30+
# DO NOT import from web if it can be avoided, to avoid circular imports
3031
from tidy3d.web.core.types import PayType
3132

3233
# fwidth of gaussian pulse in units of central frequency
@@ -37,73 +38,6 @@
3738
IndexType = TypeVar("IndexType")
3839
ElementType = TypeVar("ElementType")
3940

40-
if TYPE_CHECKING:
41-
from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType
42-
43-
from .types import ComponentModelerType
44-
45-
46-
def _run_component_modeler(
47-
modeler: ComponentModelerType,
48-
task_name: str,
49-
folder_name: str,
50-
path: str,
51-
callback_url: Optional[str],
52-
verbose: bool,
53-
solver_version: Optional[str],
54-
local_gradient: bool,
55-
max_num_adjoint_per_fwd: int,
56-
pay_type: Union[PayType, str],
57-
) -> ComponentModelerDataType:
58-
"""Run a Component Modeler via autograd by batching its underlying simulations."""
59-
60-
from tidy3d.web.api.autograd.autograd import DEFAULT_DATA_DIR, _run_async
61-
62-
path_dir = os.dirname(path) if path else DEFAULT_DATA_DIR
63-
if not path_dir:
64-
path_dir = DEFAULT_DATA_DIR
65-
66-
sims = modeler.sim_dict
67-
68-
sim_data_map = _run_async(
69-
simulations=sims,
70-
folder_name=folder_name,
71-
path_dir=path_dir,
72-
callback_url=callback_url,
73-
verbose=verbose,
74-
simulation_type="tidy3d_autograd_async",
75-
solver_version=solver_version,
76-
parent_tasks=None,
77-
local_gradient=local_gradient,
78-
max_num_adjoint_per_fwd=max_num_adjoint_per_fwd,
79-
pay_type=pay_type,
80-
)
81-
82-
return _compose_modeler_data_from_sim_map(modeler=modeler, sim_data_map=sim_data_map)
83-
84-
85-
def _compose_modeler_data_from_sim_map(
86-
modeler: ComponentModelerType, sim_data_map: dict
87-
) -> ComponentModelerDataType:
88-
"""Create ComponentModelerDataType from a dict of SimulationData keyed by task name."""
89-
90-
# local imports to avoid cycles through tidy3d.web
91-
from tidy3d.components.data.index import IndexSimulationData
92-
from tidy3d.plugins.smatrix.component_modelers.modal import ComponentModeler
93-
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
94-
from tidy3d.plugins.smatrix.data.modal import ComponentModelerData
95-
from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
96-
97-
# preserve mapping order
98-
index = tuple(sim_data_map.keys())
99-
data = tuple(sim_data_map.values())
100-
indexed = IndexSimulationData(index=index, data=data)
101-
102-
if isinstance(modeler, ComponentModeler):
103-
return ComponentModelerData(modeler=modeler, data=indexed)
104-
if isinstance(modeler, TerminalComponentModeler):
105-
return TerminalComponentModelerData(modeler=modeler, data=indexed)
106-
10741

10842
class AbstractComponentModeler(ABC, Generic[IndexType, ElementType], Tidy3dBaseModel):
10943
"""Tool for modeling devices and computing port parameters."""
@@ -298,6 +232,7 @@ def run(
298232
deprecation_warning: bool = True,
299233
):
300234
"""Run component modeler locally, with autograd support."""
235+
from tidy3d.plugins.smatrix.run import _run_component_modeler
301236

302237
if deprecation_warning:
303238
log.warning(

tidy3d/plugins/smatrix/run.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,40 @@
22

33
import json
44
import os
5-
from typing import Optional
5+
from typing import Optional, Union
66

77
from tidy3d.components.base import Tidy3dBaseModel
8+
from tidy3d.components.data.index import SimulationDataMap
89
from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler
910
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
1011
from tidy3d.plugins.smatrix.component_modelers.types import (
1112
ComponentModelerType,
1213
)
13-
from tidy3d.plugins.smatrix.data.modal import ModalComponentModelerData, SimulationDataMap
14+
from tidy3d.plugins.smatrix.data.modal import ModalComponentModelerData
1415
from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
1516
from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType
1617
from tidy3d.web import Batch, BatchData
18+
from tidy3d.web.api.autograd.autograd import DEFAULT_DATA_DIR, _run_async
19+
from tidy3d.web.core.types import PayType
1720

18-
DEFAULT_DATA_DIR = "."
1921

22+
def compose_simulation_data_map(sim_data_map: dict) -> SimulationDataMap:
23+
# preserve mapping order
24+
index = tuple(sim_data_map.keys())
25+
data = tuple(sim_data_map.values())
26+
indexed = SimulationDataMap(keys=index, values=data)
27+
return indexed
2028

21-
def compose_simulation_data_index(port_task_map: dict[str, str]) -> SimulationDataMap:
22-
port_data_dict = {}
23-
for _, _ in port_task_map.items():
24-
pass
25-
# FIXME: get simulationdata for each port
26-
# port_data_dict[port] = sim_data_i
2729

28-
return SimulationDataMap(
29-
keys=tuple(port_data_dict.keys()), values=tuple(port_data_dict.values())
30-
)
30+
def _compose_modeler_data_from_sim_map(
31+
modeler: ComponentModelerType, sim_data_map: dict
32+
) -> ComponentModelerDataType:
33+
"""Create ComponentModelerDataType from a dict of SimulationData keyed by task name."""
34+
indexed = compose_simulation_data_map(sim_data_map)
35+
if isinstance(modeler, ModalComponentModeler):
36+
return ModalComponentModelerData(modeler=modeler, data=indexed)
37+
if isinstance(modeler, TerminalComponentModeler):
38+
return TerminalComponentModelerData(modeler=modeler, data=indexed)
3139

3240

3341
def compose_terminal_modeler_data(
@@ -45,7 +53,7 @@ def compose_terminal_modeler_data(
4553
A `TerminalComponentModelerData` object containing the results mapped to
4654
their respective ports.
4755
"""
48-
port_simulation_data = compose_simulation_data_index(port_task_map)
56+
port_simulation_data = compose_simulation_data_map(port_task_map)
4957
return TerminalComponentModelerData(modeler=modeler, data=port_simulation_data)
5058

5159

@@ -65,7 +73,7 @@ def compose_component_modeler_data(
6573
A `ModalComponentModelerData` object containing the results mapped to
6674
their respective ports.
6775
"""
68-
port_simulation_data = compose_simulation_data_index(port_task_map)
76+
port_simulation_data = compose_simulation_data_map(port_task_map)
6977
return ModalComponentModelerData(modeler=modeler, data=port_simulation_data)
7078

7179

@@ -98,7 +106,7 @@ def compose_modeler(
98106
elif modeler_type == "TerminalComponentModeler":
99107
modeler = TerminalComponentModeler.from_file(modeler_file)
100108
else:
101-
raise TypeError(f"Unsupported modeler type: {type(modeler).__name__}")
109+
raise TypeError(f"Unsupported modeler type: {modeler_type}")
102110
return modeler
103111

104112

@@ -273,3 +281,39 @@ def run(
273281
batch_data = batch.run()
274282
modeler_data = compose_modeler_data_from_batch_data(modeler=modeler, batch_data=batch_data)
275283
return modeler_data
284+
285+
286+
def _run_component_modeler(
287+
modeler: ComponentModelerType,
288+
task_name: str,
289+
folder_name: str,
290+
path: str,
291+
callback_url: Optional[str],
292+
verbose: bool,
293+
solver_version: Optional[str],
294+
local_gradient: bool,
295+
max_num_adjoint_per_fwd: int,
296+
pay_type: Union[PayType, str],
297+
) -> ComponentModelerDataType:
298+
"""Run a Component Modeler via autograd by batching its underlying simulations."""
299+
path_dir = os.dirname(path) if path else DEFAULT_DATA_DIR
300+
if not path_dir:
301+
path_dir = DEFAULT_DATA_DIR
302+
303+
sims = modeler.sim_dict
304+
305+
sim_data_map = _run_async(
306+
simulations=sims,
307+
folder_name=folder_name,
308+
path_dir=path_dir,
309+
callback_url=callback_url,
310+
verbose=verbose,
311+
simulation_type="tidy3d_autograd_async",
312+
solver_version=solver_version,
313+
parent_tasks=None,
314+
local_gradient=local_gradient,
315+
max_num_adjoint_per_fwd=max_num_adjoint_per_fwd,
316+
pay_type=pay_type,
317+
)
318+
319+
return _compose_modeler_data_from_sim_map(modeler=modeler, sim_data_map=sim_data_map)

0 commit comments

Comments
 (0)