 
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, cast, Optional, Union
 
 import torch
 import torch.fx as fx
+
+from autoparallel.utils import DebugInterpreter
 from torch.distributed.pipelining.schedules import (
     _Action,
     _PipelineContext,
     _PipelineScheduleRuntime,
     _wait_batch_p2p,
 )
 from torch.distributed.pipelining.stage import (
-    PipelineStage,
     _normalize_model_output_as_tuple,
+    PipelineStage,
 )
 from torch.distributed.tensor import DTensor
 
-from autoparallel.utils import DebugInterpreter
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
@@ -275,6 +275,11 @@ def stage_forward(
     # Receive activations for this chunk
     # Activations only come in args form
     composite_args = stage._retrieve_recv_activations(mb_index)
+    if stage.is_last and ctx.target_mbs is not None:
+        assert isinstance(
+            composite_args, tuple
+        ), f"Expected composite args to be a tuple but got {type(composite_args)}"
+        composite_args = composite_args + (ctx.target_mbs[mb_index],)  # type: ignore[index]
 
     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
     logger.debug(
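
The added branch routes the label microbatch into the last stage by appending ctx.target_mbs[mb_index] to the received activations, so the last stage's forward appears to take the target as its final positional input and return the loss. A minimal sketch of that calling convention with a hypothetical last-stage module (names and shapes are illustrative, not part of this patch):

    import torch
    import torch.nn as nn

    class LastStageWithLoss(nn.Module):
        """Hypothetical last stage: activations plus target in, scalar loss out."""

        def __init__(self, hidden: int, vocab: int):
            super().__init__()
            self.head = nn.Linear(hidden, vocab)

        def forward(self, hidden_states: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            logits = self.head(hidden_states)
            # Computing the loss inside the stage means the schedule only sees a scalar output.
            return nn.functional.cross_entropy(logits.flatten(0, 1), target.flatten())

    stage_module = LastStageWithLoss(hidden=16, vocab=32)
    composite_args = (torch.randn(2, 4, 16),)  # activations received from the previous stage
    composite_args = composite_args + (torch.randint(0, 32, (2, 4)),)  # target appended, mirroring the patch
    loss = stage_module(*composite_args)
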
@@ -292,6 +297,8 @@ def stage_forward(
     # Output chunks is only used for the last stage since we only merge the output of the last stage
     if stage.is_last:
         stage.output_chunks.append(output)
+        if ctx.target_mbs is not None:
+            ctx.schedule_ref._internal_losses.append(output)
 
     stage.fwd_cache[mb_index] = (
         output_tuple,  # stage_output
@@ -360,7 +367,7 @@ def stage_full_backward(
         # HACK till we have loss function, we populate the tangents here manually
         bwd_kwargs = {
             "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output[0])],
+            "tangents": [torch.ones_like(stage_output[0])],
             "saved_intermediates": saved_intermediates,
         }
     else:
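
Switching the manual tangents from torch.randn_like to torch.ones_like matches autograd's convention: seeding the backward pass with ones reproduces the gradients of loss.backward() for a scalar loss (and of loss.sum().backward() otherwise), while a random tangent rescales every gradient by an arbitrary factor. A standalone check of that equivalence:

    import torch

    w = torch.randn(4, requires_grad=True)
    x = torch.randn(4)

    # Standard scalar backward.
    loss = (w * x).sum()
    loss.backward()
    grad_plain = w.grad.clone()

    # Same gradients when the tangent is an explicit ones tensor.
    w.grad = None
    loss = (w * x).sum()
    loss.backward(gradient=torch.ones_like(loss))
    assert torch.allclose(grad_plain, w.grad)
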
@@ -525,7 +532,9 @@ def _accumulate_stage_grads_and_clear_states(
         stage.state.clear()
 
     def step(self, *args, **kwargs) -> None:
-
+        has_targets_and_loss = (
+            "losses" in kwargs and "targets" in kwargs if kwargs else False
+        )
         for stage in self.schedule._stages:
             assert isinstance(stage, GraphPipelineStage)
             self._populate_stage_states(stage)
@@ -535,3 +544,11 @@ def step(self, *args, **kwargs) -> None:
         for stage in self.schedule._stages:
             assert isinstance(stage, GraphPipelineStage)
             self._accumulate_stage_grads_and_clear_states(stage)
+            if stage.is_last and has_targets_and_loss:
+                losses = kwargs["losses"]
+                losses.clear()
+                assert (
+                    len(self.schedule._internal_losses) == self.schedule._n_microbatches
+                )
+                losses.extend(self.schedule._internal_losses)
+                self.schedule._internal_losses.clear()
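
Taken together, these changes let callers recover per-microbatch losses from step by passing both targets and a losses list, which is cleared and refilled with one entry per microbatch on the last stage's rank. A hedged usage sketch; FakeSchedule below is a stand-in for the graph pipeline runner, included only to show the calling convention and not part of this patch:

    import torch

    class FakeSchedule:
        """Stand-in runner illustrating the losses/targets convention of step()."""

        def __init__(self, n_microbatches: int):
            self._n_microbatches = n_microbatches
            self._internal_losses: list[torch.Tensor] = []

        def step(self, *args, **kwargs) -> None:
            has_targets_and_loss = (
                "losses" in kwargs and "targets" in kwargs if kwargs else False
            )
            # Pretend the last stage appended one loss per microbatch while running.
            self._internal_losses = [torch.rand(()) for _ in range(self._n_microbatches)]
            if has_targets_and_loss:
                losses = kwargs["losses"]
                losses.clear()
                assert len(self._internal_losses) == self._n_microbatches
                losses.extend(self._internal_losses)
                self._internal_losses.clear()

    runner = FakeSchedule(n_microbatches=4)
    losses: list[torch.Tensor] = []
    runner.step(torch.randn(8, 16), targets=torch.randint(0, 10, (8,)), losses=losses)
    print(torch.stack(losses).mean())  # average microbatch loss, e.g. for logging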