@@ -8,7 +8,7 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
@@ -159,7 +159,9 @@ def enable_local_map_wrapping():
     yield
 
 
-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any]
+) -> torch.fx.GraphModule:
     """
     Thin wrapper around graph capture output that restores the
     original calling convention and attribute fqn. TODO:
@@ -169,7 +171,7 @@ def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
     3) Be more careful about tensor constants names.
     """
     with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        gm = _dynamo_graph_capture_for_export(model_wrapper)(*inputs)
     _restore_state_dict(model, gm)
     return gm
 
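A minimal sketch (not part of the diff) of the new _export contract: graph capture now traces an arbitrary model_wrapper callable, while the state dict and attribute FQNs are still restored from the original model, so the wrapper may close over the model and fuse extra computation such as a loss into the captured graph. The toy model, target, and MSE loss below are illustrative assumptions.

    import torch

    model = torch.nn.Linear(4, 2)
    target = torch.zeros(8, 2)

    def wrapper(x):
        # Extra computation folded into the captured graph.
        return torch.nn.functional.mse_loss(model(x), target)

    # Traces `wrapper`, restores FQNs from `model`, returns a torch.fx.GraphModule.
    gm = _export(model, wrapper, (torch.randn(8, 4),))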
@@ -193,6 +195,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        loss_fn: Optional[Callable] = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -224,6 +227,7 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
+        self.loss_fn = loss_fn
 
         if dynamic:
             self.fake_mode.shape_env = ShapeEnv()
@@ -294,11 +298,27 @@ def build_model_graph(self):
         inputs = self.input_fn()
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
+        model_wrapper: Callable
+        if self.loss_fn is not None:
+
+            def model_with_loss(input, target) -> Any:
+                output = self.model(input)
+                loss = self.loss_fn(output, target)  # type: ignore[misc]
+                return loss
+
+            model_wrapper = model_with_loss
+        else:
+
+            def model_wo_loss(input) -> Any:
+                output = self.model(input)
+                return output
+
+            model_wrapper = model_wo_loss
 
         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            torch_ir_with_fqn = _export(self.model, inputs)
+            torch_ir_with_fqn = _export(self.model, model_wrapper, inputs)
         # TODO Can't use fake mode here because it clashes with the user level
         # fake mode. Ideally dynamo should reuse the user level fake mode.
         self.joint_with_descriptors = aot_export_joint_with_descriptors(
@@ -326,6 +346,7 @@ def build_model_graph(self):
                 print_output=False, include_stride=True, include_device=True
             ),
         )
+        print(gm.graph)
 
         self.gm = gm
 
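A rough usage sketch of the new loss_fn hook. The class whose __init__ and build_model_graph are patched above is not named in this diff, so Engine and its constructor arguments below are placeholders; what the diff does establish is that when loss_fn is provided, input_fn must return an (input, target) pair, since the traced wrapper computes self.loss_fn(self.model(input), target) and folds the loss into the joint graph.

    import torch

    model = torch.nn.Linear(16, 8)

    def input_fn():
        # With loss_fn set, the capture expects (input, target).
        return torch.randn(4, 16), torch.randint(0, 8, (4,))

    engine = Engine(  # placeholder name for the class patched in this diff
        model,
        input_fn=input_fn,
        loss_fn=torch.nn.functional.cross_entropy,  # new keyword added by this change
    )
    engine.build_model_graph()  # captures forward + loss as a single graph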