@@ -8,7 +8,7 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
@@ -159,7 +159,9 @@ def enable_local_map_wrapping():
     yield
 
 
-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any]
+) -> torch.fx.GraphModule:
     """
     Thin wrapper around graph capture output that restores the
     original calling convention and attribute fqn. TODO:
@@ -169,7 +171,7 @@ def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
     3) Be more careful about tensor constants names.
     """
     with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        gm = _dynamo_graph_capture_for_export(model_wrapper)(*inputs)
     _restore_state_dict(model, gm)
     return gm
 
@@ -193,6 +195,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        loss_fn: Optional[Callable] = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -224,6 +227,7 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
+        self.loss_fn = loss_fn
 
         if dynamic:
             self.fake_mode.shape_env = ShapeEnv()
@@ -294,11 +298,29 @@ def build_model_graph(self):
         inputs = self.input_fn()
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
+        model_wrapper: Callable
+        if self.loss_fn is not None:
+
+            def model_with_loss(inputs, target) -> Any:
+                if not isinstance(inputs, tuple):
+                    inputs = (inputs,)
+                output = self.model(*inputs)
+                loss = self.loss_fn(output, target)  # type: ignore[misc]
+                return loss
+
+            model_wrapper = model_with_loss
+        else:
+
+            def model_wo_loss(*inputs) -> Any:
+                output = self.model(*inputs)
+                return output
+
+            model_wrapper = model_wo_loss
 
         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            torch_ir_with_fqn = _export(self.model, inputs)
+            torch_ir_with_fqn = _export(self.model, model_wrapper, inputs)
             # TODO Can't use fake mode here because it clashes with the user level
             # fake mode. Ideally dynamo should reuse the user level fake mode.
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
@@ -326,6 +348,7 @@ def build_model_graph(self):
                 print_output=False, include_stride=True, include_device=True
             ),
         )
+        print(gm.graph)
 
         self.gm = gm
 
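A minimal sketch of the calling contract implied by the new model_with_loss(inputs, target) wrapper; the model, tensor shapes, and cross-entropy choice below are illustrative assumptions, not taken from this diff:

import torch
import torch.nn.functional as F

# Hypothetical model; any nn.Module works.
model = torch.nn.Linear(16, 4)

def input_fn():
    # When loss_fn is provided, input_fn() must return (inputs, target) so
    # that *inputs unpacks into the two arguments of model_with_loss.
    x = torch.randn(8, 16)              # model input (a single tensor here)
    target = torch.randint(0, 4, (8,))  # labels consumed by loss_fn
    return (x, target)

# Any callable(output, target) -> loss works here; with it in place the
# captured joint graph ends in the loss rather than the raw model output.
loss_fn = F.cross_entropy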